Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
777d5f95
Commit
777d5f95
authored
Apr 18, 2025
by
zhuwenwen
Browse files
deepseek_v3/r1 int8 量化首字调优
parent
f587d8f7
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
111 deletions
+73
-111
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+34
-39
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+39
-72
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
777d5f95
...
@@ -49,7 +49,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
...
@@ -49,7 +49,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#256
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
}
]
]
stage2_best_config
=
[
stage2_best_config
=
[
...
@@ -69,7 +72,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
...
@@ -69,7 +72,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#16
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
}
,
#256
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
# 8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
]
]
else
:
else
:
stage1_best_config
=
[
stage1_best_config
=
[
...
@@ -90,7 +97,10 @@ else:
...
@@ -90,7 +97,10 @@ else:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#256
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
]
]
stage2_best_config
=
[
stage2_best_config
=
[
...
@@ -110,7 +120,11 @@ else:
...
@@ -110,7 +120,11 @@ else:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#16
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#32
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#256
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
}
]
]
@
triton
.
jit
@
triton
.
jit
...
@@ -1644,6 +1658,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1644,6 +1658,7 @@ def fused_experts_impl(hidden_states: torch.Tensor,
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
topk_ids
.
shape
[
1
]]
topk_ids
.
shape
[
1
]]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
if
not
use_int8_w8a8
:
config
=
get_config_func
(
tokens_in_chunk
)
config
=
get_config_func
(
tokens_in_chunk
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
...
@@ -1657,24 +1672,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1657,24 +1672,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
=
stage1_best_config
[
15
]
config
=
stage1_best_config
[
15
]
elif
m
<=
64
:
elif
m
<=
64
:
config
=
stage1_best_config
[
16
]
config
=
stage1_best_config
[
16
]
elif
m
<
256
:
elif
m
<=
256
:
config
=
{
config
=
stage1_best_config
[
17
]
"BLOCK_SIZE_M"
:
16
,
elif
m
<=
1024
:
"BLOCK_SIZE_N"
:
32
,
config
=
stage1_best_config
[
18
]
"BLOCK_SIZE_K"
:
64
,
elif
m
<=
8192
:
"GROUP_SIZE_M"
:
1
,
config
=
stage1_best_config
[
19
]
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
else
:
config
=
{
config
=
stage1_best_config
[
20
]
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
qcurr_hidden_states
,
qa1_scale
=
moe_kernel_prepare_input
(
A
=
curr_hidden_states
,
A
=
curr_hidden_states
,
...
@@ -1748,24 +1753,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1748,24 +1753,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
=
stage2_best_config
[
15
]
config
=
stage2_best_config
[
15
]
elif
m
<=
64
:
elif
m
<=
64
:
config
=
stage2_best_config
[
16
]
config
=
stage2_best_config
[
16
]
elif
m
<
256
:
elif
m
<=
256
:
config
=
{
config
=
stage2_best_config
[
17
]
"BLOCK_SIZE_M"
:
16
,
elif
m
<=
1024
:
"BLOCK_SIZE_N"
:
32
,
config
=
stage2_best_config
[
18
]
"BLOCK_SIZE_K"
:
64
,
elif
m
<=
8192
:
"GROUP_SIZE_M"
:
1
,
config
=
stage2_best_config
[
19
]
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
else
:
config
=
{
config
=
stage2_best_config
[
20
]
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
invoke_fused_moe_kernel
(
qintermediate_cache2
,
invoke_fused_moe_kernel
(
qintermediate_cache2
,
w2
,
w2
,
...
...
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
777d5f95
...
@@ -149,9 +149,12 @@ def _per_token_group_quant_int8(
...
@@ -149,9 +149,12 @@ def _per_token_group_quant_int8(
y_q_ptr
,
y_q_ptr
,
y_s_ptr
,
y_s_ptr
,
# Stride of input
# Stride of input
y_stride
,
group_size
,
# Columns of input
# M,
N
,
# K,
# # Columns of input
# N,
SIZE
,
# Avoid division by zero
# Avoid division by zero
eps
,
eps
,
# Information for int8
# Information for int8
...
@@ -159,6 +162,7 @@ def _per_token_group_quant_int8(
...
@@ -159,6 +162,7 @@ def _per_token_group_quant_int8(
int8_max
,
int8_max
,
# Meta-parameters
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
s_num
:
tl
.
constexpr
,
):
):
"""A Triton-accelerated function to perform per-token-group
"""A Triton-accelerated function to perform per-token-group
quantization on a tensor.
quantization on a tensor.
...
@@ -167,21 +171,26 @@ def _per_token_group_quant_int8(
...
@@ -167,21 +171,26 @@ def _per_token_group_quant_int8(
"""
"""
# Map the program id to the row of X and Y it should compute.
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
g_id
=
tl
.
program_id
(
0
)
y_ptr
+=
g_id
*
y_stride
y_ptr
+=
g_id
*
BLOCK
y_q_ptr
+=
g_id
*
y_stride
y_q_ptr
+=
g_id
*
BLOCK
y_s_ptr
+=
g_id
y_s_ptr
+=
g_id
*
s_num
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
mask
=
cols
<
N
s_cols
=
tl
.
arange
(
0
,
s_num
)
mask
=
g_id
*
BLOCK
+
cols
<
SIZE
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
y
=
tl
.
reshape
(
y
,
(
s_num
,
128
))
# Quant
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)
,
axis
=
1
),
eps
)
y_s
=
_absmax
/
int8_max
y_s
=
(
_absmax
/
int8_max
).
reshape
(
s_num
,
1
)
y_q
=
tl
.
clamp
(
y
/
y_s
,
int8_min
,
int8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
y_q
=
tl
.
clamp
(
y
/
y_s
,
int8_min
,
int8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
y_q
=
tl
.
reshape
(
y_q
,
(
s_num
*
128
))
y_s
=
tl
.
reshape
(
y_s
,
(
s_num
))
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
tl
.
store
(
y_s_ptr
+
s_cols
,
y_s
.
to
(
y_s_ptr
.
dtype
.
element_ty
)
)
def
per_token_group_quant_int8
(
def
per_token_group_quant_int8
(
...
@@ -215,8 +224,20 @@ def per_token_group_quant_int8(
...
@@ -215,8 +224,20 @@ def per_token_group_quant_int8(
int8_min
=
iinfo
.
min
int8_min
=
iinfo
.
min
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
N
=
group_size
m
=
x
.
shape
[
0
]
if
m
<=
16
:
config
=
{
"BLOCK"
:
128
,
"s_num"
:
1
,
"num_warps"
:
1
,
"num_stages"
:
1
}
elif
m
<=
256
:
config
=
{
"BLOCK"
:
1024
,
"s_num"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
1
}
else
:
config
=
{
"BLOCK"
:
2048
,
"s_num"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
}
grid
=
lambda
META
:
(
triton
.
cdiv
(
x
.
numel
(),
META
[
'BLOCK'
]),
)
x_s
=
torch
.
empty
(
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
),
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,
),
device
=
x
.
device
,
device
=
x
.
device
,
...
@@ -227,18 +248,19 @@ def per_token_group_quant_int8(
...
@@ -227,18 +248,19 @@ def per_token_group_quant_int8(
# heuristics for number of warps
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_stages
=
1
num_stages
=
1
_per_token_group_quant_int8
[
(
M
,
)
](
_per_token_group_quant_int8
[
grid
](
x
,
x
,
x_q
,
x_q
,
x_s
,
x_s
,
group_size
,
group_size
,
N
,
# M,
# K,
# N,
x
.
numel
(),
eps
,
eps
,
int8_min
=
int8_min
,
int8_min
=
int8_min
,
int8_max
=
int8_max
,
int8_max
=
int8_max
,
BLOCK
=
BLOCK
,
**
config
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
)
return
x_q
,
x_s
return
x_q
,
x_s
...
@@ -534,61 +556,6 @@ def w8a8_block_int8_matmul(
...
@@ -534,61 +556,6 @@ def w8a8_block_int8_matmul(
return
C
return
C
def native_w8a8_block_int8_matmul(A, B, As, Bs, block_size, output_dtype=torch.float16):
    """Reference block-wise quantized matrix multiplication in native torch.

    Computes ``C = A @ B.T`` tile by tile, multiplying every partial product
    by the matching activation scale from ``As`` and weight scale from ``Bs``.
    All arithmetic, including the scale product, is done in float32 and the
    result is cast to ``output_dtype`` at the very end.

    Args:
        A: Quantized activations of shape ``(..., K)``.
        B: Quantized weight, a contiguous 2-D tensor of shape ``(N, K)``.
        As: Activation scales of shape ``(..., ceil(K / block_k))``; leading
            dimensions must equal those of ``A``.
        Bs: Weight scales of shape
            ``(ceil(N / block_n), ceil(K / block_k))``.
        block_size: Two-element sequence ``[block_n, block_k]``.
        output_dtype: dtype of the returned tensor.

    Returns:
        Tensor of shape ``A.shape[:-1] + (N,)`` in ``output_dtype``.

    Raises:
        AssertionError: If the shapes of the inputs are inconsistent.
    """
    A = A.to(torch.float32)
    B = B.to(torch.float32)
    # Upcast the scales as well: multiplying two fp16/bf16 scales in their
    # own dtype can underflow or lose precision before the product ever
    # reaches the float32 accumulator.
    As = As.to(torch.float32)
    Bs = Bs.to(torch.float32)

    assert A.shape[-1] == B.shape[-1]
    assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
    assert len(block_size) == 2
    block_n, block_k = block_size[0], block_size[1]
    assert (A.shape[-1] + block_k - 1) // block_k == As.shape[-1]
    assert A.shape[:-1] == As.shape[:-1]

    N, K = B.shape
    lead_shape = tuple(A.shape[:-1])
    # Number of rows after flattening the leading dims; computed as a product
    # so that K == 0 does not trigger a division by zero.
    M = 1
    for dim in lead_shape:
        M *= dim
    origin_C_shape = lead_shape + (N,)

    A = A.reshape(M, K)
    As = As.reshape(M, As.shape[-1])

    n_tiles = (N + block_n - 1) // block_n
    k_tiles = (K + block_k - 1) // block_k
    assert n_tiles == Bs.shape[0]
    assert k_tiles == Bs.shape[1]

    C = torch.zeros((M, N), dtype=torch.float32, device=A.device)

    for i in range(k_tiles):
        k_lo, k_hi = i * block_k, min((i + 1) * block_k, K)
        a_tile = A[:, k_lo:k_hi]
        a_scale = As[:, i:i + 1]  # (M, 1): one scale per row for this K-tile
        for j in range(n_tiles):
            n_lo, n_hi = j * block_n, min((j + 1) * block_n, N)
            b_tile = B[n_lo:n_hi, k_lo:k_hi]
            # Combined scale broadcasts over the columns of the output tile.
            scale = a_scale * Bs[j, i]
            C[:, n_lo:n_hi] += torch.matmul(a_tile, b_tile.t()) * scale

    return C.reshape(origin_C_shape).to(output_dtype)
def
apply_w8a8_block_int8_linear
(
def
apply_w8a8_block_int8_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment