Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
8fc15e04
Commit
8fc15e04
authored
Apr 18, 2025
by
gaoqiong
Browse files
deepseek_v3/r1 int8 量化首字调优
parent
f5f9f42f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
73 additions
and
115 deletions
+73
-115
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+35
-40
vllm/model_executor/layers/quantization/utils/int8_utils.py
vllm/model_executor/layers/quantization/utils/int8_utils.py
+38
-75
No files found.
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
8fc15e04
...
@@ -42,7 +42,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
...
@@ -42,7 +42,10 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#256
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
8
}
]
]
stage2_best_config
=
[
stage2_best_config
=
[
...
@@ -62,7 +65,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
...
@@ -62,7 +65,11 @@ if device_name=='K100_AI' and torch.cuda.get_device_properties(torch.cuda.curren
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#14
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#15
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#16
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
}
,
#256
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
# 8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"kpack"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
}
]
]
else
:
else
:
stage1_best_config
=
[
stage1_best_config
=
[
...
@@ -83,7 +90,10 @@ else:
...
@@ -83,7 +90,10 @@ else:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#32
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#256
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
#8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
8
},
]
]
stage2_best_config
=
[
stage2_best_config
=
[
...
@@ -103,7 +113,11 @@ else:
...
@@ -103,7 +113,11 @@ else:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#13
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#14
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#15
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#16
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"num_stages"
:
0
,
"num_warps"
:
2
},
#32
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#256
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#1024
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
},
#8192
{
"BLOCK_SIZE_M"
:
32
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
,
"kpack"
:
2
,
"num_stages"
:
0
,
"num_warps"
:
4
}
]
]
@
triton
.
jit
@
triton
.
jit
...
@@ -1662,9 +1676,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1662,9 +1676,10 @@ def fused_experts_impl(hidden_states: torch.Tensor,
# so the cache size and config are already set correctly and
# so the cache size and config are already set correctly and
# do not need to be adjusted.
# do not need to be adjusted.
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache1
=
intermediate_cache1
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
]
intermediate_cache2
=
intermediate_cache2
[:
tokens_in_chunk
*
topk_ids
.
shape
[
1
]
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
intermediate_cache3
=
intermediate_cache3
[:
tokens_in_chunk
]
config
=
get_config_func
(
tokens_in_chunk
)
if
not
use_int8_w8a8
:
config
=
get_config_func
(
tokens_in_chunk
)
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_ids
=
topk_ids
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
curr_topk_weights
=
topk_weights
[
begin_chunk_idx
:
end_chunk_idx
]
...
@@ -1677,24 +1692,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1677,24 +1692,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
=
stage1_best_config
[
15
]
config
=
stage1_best_config
[
15
]
elif
m
<=
64
:
elif
m
<=
64
:
config
=
stage1_best_config
[
16
]
config
=
stage1_best_config
[
16
]
elif
m
<
256
:
elif
m
<=
256
:
config
=
{
config
=
stage1_best_config
[
17
]
"BLOCK_SIZE_M"
:
16
,
elif
m
<=
1024
:
"BLOCK_SIZE_N"
:
32
,
config
=
stage1_best_config
[
18
]
"BLOCK_SIZE_K"
:
64
,
elif
m
<=
8192
:
"GROUP_SIZE_M"
:
1
,
config
=
stage1_best_config
[
19
]
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
else
:
config
=
{
config
=
stage1_best_config
[
20
]
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
if
moe_ep_size
==
1
:
if
moe_ep_size
==
1
:
if
use_int4_w4a16
:
if
use_int4_w4a16
:
...
@@ -1740,24 +1745,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
...
@@ -1740,24 +1745,14 @@ def fused_experts_impl(hidden_states: torch.Tensor,
config
=
stage2_best_config
[
15
]
config
=
stage2_best_config
[
15
]
elif
m
<=
64
:
elif
m
<=
64
:
config
=
stage2_best_config
[
16
]
config
=
stage2_best_config
[
16
]
elif
m
<
256
:
elif
m
<=
256
:
config
=
{
config
=
stage2_best_config
[
17
]
"BLOCK_SIZE_M"
:
16
,
elif
m
<=
1024
:
"BLOCK_SIZE_N"
:
32
,
config
=
stage2_best_config
[
18
]
"BLOCK_SIZE_K"
:
64
,
elif
m
<=
8192
:
"GROUP_SIZE_M"
:
1
,
config
=
stage2_best_config
[
19
]
"num_stages"
:
0
,
"num_warps"
:
4
}
else
:
else
:
config
=
{
config
=
stage2_best_config
[
20
]
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
32
,
"GROUP_SIZE_M"
:
8
,
"num_stages"
:
0
,
"num_warps"
:
4
}
invoke_fused_moe_kernel
(
intermediate_cache2
,
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
w2
,
...
...
vllm/model_executor/layers/quantization/utils/int8_utils.py
View file @
8fc15e04
...
@@ -68,7 +68,6 @@ def per_token_quant_int8(x):
...
@@ -68,7 +68,6 @@ def per_token_quant_int8(x):
return
x_q
,
scales
return
x_q
,
scales
@
triton
.
jit
@
triton
.
jit
def
_per_token_group_quant_int8
(
def
_per_token_group_quant_int8
(
# Pointers to inputs and output
# Pointers to inputs and output
...
@@ -76,9 +75,12 @@ def _per_token_group_quant_int8(
...
@@ -76,9 +75,12 @@ def _per_token_group_quant_int8(
y_q_ptr
,
y_q_ptr
,
y_s_ptr
,
y_s_ptr
,
# Stride of input
# Stride of input
y_stride
,
group_size
,
# Collums of input
# M,
N
,
# K,
# # Collums of input
# N,
SIZE
,
# Avoid to divide zero
# Avoid to divide zero
eps
,
eps
,
# Information for int8
# Information for int8
...
@@ -86,6 +88,7 @@ def _per_token_group_quant_int8(
...
@@ -86,6 +88,7 @@ def _per_token_group_quant_int8(
int8_max
,
int8_max
,
# Meta-parameters
# Meta-parameters
BLOCK
:
tl
.
constexpr
,
BLOCK
:
tl
.
constexpr
,
s_num
:
tl
.
constexpr
,
):
):
"""A Triton-accelerated function to perform
"""A Triton-accelerated function to perform
per-token-group quantization on a tensor.
per-token-group quantization on a tensor.
...
@@ -93,21 +96,26 @@ def _per_token_group_quant_int8(
...
@@ -93,21 +96,26 @@ def _per_token_group_quant_int8(
"""
"""
# Map the program id to the row of X and Y it should compute.
# Map the program id to the row of X and Y it should compute.
g_id
=
tl
.
program_id
(
0
)
g_id
=
tl
.
program_id
(
0
)
y_ptr
+=
g_id
*
y_stride
y_ptr
+=
g_id
*
BLOCK
y_q_ptr
+=
g_id
*
y_stride
y_q_ptr
+=
g_id
*
BLOCK
y_s_ptr
+=
g_id
y_s_ptr
+=
g_id
*
s_num
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
cols
=
tl
.
arange
(
0
,
BLOCK
)
# N <= BLOCK
mask
=
cols
<
N
s_cols
=
tl
.
arange
(
0
,
s_num
)
mask
=
g_id
*
BLOCK
+
cols
<
SIZE
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
y
=
tl
.
load
(
y_ptr
+
cols
,
mask
=
mask
,
other
=
0.0
).
to
(
tl
.
float32
)
y
=
tl
.
reshape
(
y
,
(
s_num
,
128
))
# Quant
# Quant
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)),
eps
)
_absmax
=
tl
.
maximum
(
tl
.
max
(
tl
.
abs
(
y
)
,
axis
=
1
),
eps
)
y_s
=
_absmax
/
int8_max
y_s
=
(
_absmax
/
int8_max
).
reshape
(
s_num
,
1
)
y_q
=
tl
.
clamp
(
y
/
y_s
,
int8_min
,
int8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
y_q
=
tl
.
clamp
(
y
/
y_s
,
int8_min
,
int8_max
).
to
(
y_q_ptr
.
dtype
.
element_ty
)
y_q
=
tl
.
reshape
(
y_q
,
(
s_num
*
128
))
y_s
=
tl
.
reshape
(
y_s
,
(
s_num
))
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_q_ptr
+
cols
,
y_q
,
mask
=
mask
)
tl
.
store
(
y_s_ptr
,
y_s
)
tl
.
store
(
y_s_ptr
+
s_cols
,
y_s
.
to
(
y_s_ptr
.
dtype
.
element_ty
)
)
def
per_token_group_quant_int8
(
def
per_token_group_quant_int8
(
...
@@ -139,30 +147,38 @@ def per_token_group_quant_int8(
...
@@ -139,30 +147,38 @@ def per_token_group_quant_int8(
int8_min
=
iinfo
.
min
int8_min
=
iinfo
.
min
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
x_q
=
torch
.
empty_like
(
x
,
device
=
x
.
device
,
dtype
=
dtype
)
M
=
x
.
numel
()
//
group_size
N
=
group_size
N
=
group_size
m
=
x
.
shape
[
0
]
if
m
<=
16
:
config
=
{
"BLOCK"
:
128
,
"s_num"
:
1
,
"num_warps"
:
1
,
"num_stages"
:
1
}
elif
m
<=
256
:
config
=
{
"BLOCK"
:
1024
,
"s_num"
:
8
,
"num_warps"
:
4
,
"num_stages"
:
1
}
else
:
config
=
{
"BLOCK"
:
2048
,
"s_num"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
2
}
grid
=
lambda
META
:
(
triton
.
cdiv
(
x
.
numel
(),
META
[
'BLOCK'
]),
)
x_s
=
torch
.
empty
(
x_s
=
torch
.
empty
(
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
x
.
shape
[:
-
1
]
+
(
x
.
shape
[
-
1
]
//
group_size
,),
device
=
x
.
device
,
device
=
x
.
device
,
dtype
=
torch
.
float32
,
dtype
=
torch
.
float32
,
)
)
BLOCK
=
triton
.
next_power_of_2
(
N
)
_per_token_group_quant_int8
[
grid
](
# heuristics for number of warps
num_warps
=
min
(
max
(
BLOCK
//
256
,
1
),
8
)
num_stages
=
1
_per_token_group_quant_int8
[(
M
,)](
x
,
x
,
x_q
,
x_q
,
x_s
,
x_s
,
group_size
,
group_size
,
N
,
# M,
# K,
# N,
x
.
numel
(),
eps
,
eps
,
int8_min
=
int8_min
,
int8_min
=
int8_min
,
int8_max
=
int8_max
,
int8_max
=
int8_max
,
BLOCK
=
BLOCK
,
**
config
num_warps
=
num_warps
,
num_stages
=
num_stages
,
)
)
return
x_q
,
x_s
return
x_q
,
x_s
...
@@ -458,59 +474,6 @@ def w8a8_block_int8_matmul(
...
@@ -458,59 +474,6 @@ def w8a8_block_int8_matmul(
return
C
return
C
def
native_w8a8_block_int8_matmul
(
A
,
B
,
As
,
Bs
,
block_size
,
output_dtype
=
torch
.
float16
):
"""matrix multiplication with block-wise quantization using native torch.
It takes two input tensors `A` and `B` with scales `As` and `Bs`.
The output is returned in the specified `output_dtype`.
"""
A
=
A
.
to
(
torch
.
float32
)
B
=
B
.
to
(
torch
.
float32
)
assert
A
.
shape
[
-
1
]
==
B
.
shape
[
-
1
]
assert
B
.
ndim
==
2
and
B
.
is_contiguous
()
and
Bs
.
ndim
==
2
assert
len
(
block_size
)
==
2
block_n
,
block_k
=
block_size
[
0
],
block_size
[
1
]
assert
(
A
.
shape
[
-
1
]
+
block_k
-
1
)
//
block_k
==
As
.
shape
[
-
1
]
assert
A
.
shape
[:
-
1
]
==
As
.
shape
[:
-
1
]
M
=
A
.
numel
()
//
A
.
shape
[
-
1
]
N
,
K
=
B
.
shape
origin_C_shape
=
A
.
shape
[:
-
1
]
+
(
N
,)
A
=
A
.
reshape
(
M
,
A
.
shape
[
-
1
])
As
=
As
.
reshape
(
M
,
As
.
shape
[
-
1
])
n_tiles
=
(
N
+
block_n
-
1
)
//
block_n
k_tiles
=
(
K
+
block_k
-
1
)
//
block_k
assert
n_tiles
==
Bs
.
shape
[
0
]
assert
k_tiles
==
Bs
.
shape
[
1
]
C_shape
=
(
M
,
N
)
C
=
torch
.
zeros
(
C_shape
,
dtype
=
torch
.
float32
,
device
=
A
.
device
)
A_tiles
=
[
A
[:,
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
)]
for
i
in
range
(
k_tiles
)]
B_tiles
=
[
[
B
[
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
),
i
*
block_k
:
min
((
i
+
1
)
*
block_k
,
K
),
]
for
i
in
range
(
k_tiles
)
]
for
j
in
range
(
n_tiles
)
]
C_tiles
=
[
C
[:,
j
*
block_n
:
min
((
j
+
1
)
*
block_n
,
N
)]
for
j
in
range
(
n_tiles
)]
As_tiles
=
[
As
[:,
i
:
i
+
1
]
for
i
in
range
(
k_tiles
)]
for
i
in
range
(
k_tiles
):
for
j
in
range
(
n_tiles
):
a
=
A_tiles
[
i
]
b
=
B_tiles
[
j
][
i
]
c
=
C_tiles
[
j
]
s
=
As_tiles
[
i
]
*
Bs
[
j
][
i
]
c
[:,
:]
+=
torch
.
matmul
(
a
,
b
.
t
())
*
s
C
=
C
.
reshape
(
origin_C_shape
).
to
(
output_dtype
)
return
C
def
apply_w8a8_block_int8_linear
(
def
apply_w8a8_block_int8_linear
(
input
:
torch
.
Tensor
,
input
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment