Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2d5a25cd
Commit
2d5a25cd
authored
Dec 03, 2024
by
zhuwenwen
Browse files
Merge branch '0.6.2-w8a8' into 'v0.6.2-dev'
0.6.2 w8a8 See merge request dcutoolkit/deeplearing/vllm!43
parents
0dc55ec0
26513bb5
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
420 additions
and
245 deletions
+420
-245
README.md
README.md
+1
-1
tests/kernels/test_cutlass.py
tests/kernels/test_cutlass.py
+240
-229
vllm/_custom_ops.py
vllm/_custom_ops.py
+2
-2
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+28
-3
vllm/model_executor/models/llama.py
vllm/model_executor/models/llama.py
+33
-5
vllm/model_executor/models/qwen.py
vllm/model_executor/models/qwen.py
+30
-5
vllm/utils.py
vllm/utils.py
+86
-0
No files found.
README.md
View file @
2d5a25cd
...
...
@@ -9,7 +9,7 @@ vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention
## 支持模型结构列表
| 结构 | 模型 | FP16/BF16 | AWQ | GPTQ |
| :------: | :------: | :------: | :------: |
| :------: | :------: | :------: | :------: |
:------: |
| LlamaForCausalLM | Llama 3.1,Llama 3,Llama 2,Llama,Yi,Codellama,deepseek | Yes | Yes | Yes |
| QWenLMHeadModel | QWen,Qwen-VL | Yes | Yes | Yes |
| Qwen2ForCausalLM | QWen2,QWen1.5,CodeQwen1.5 | Yes | Yes | Yes |
...
...
tests/kernels/test_cutlass.py
View file @
2d5a25cd
...
...
@@ -9,14 +9,14 @@ import torch
from
tests.kernels.utils
import
opcheck
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
#
from vllm.platforms import current_platform
CUDA_DEVICES
=
[
f
"cuda:
{
i
}
"
for
i
in
range
(
1
if
torch
.
cuda
.
device_count
()
==
1
else
2
)
f
"cuda:
{
0
}
"
#
for i in range(1 if torch.cuda.device_count() == 1 else 2)
]
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
#
capability = current_platform.get_device_capability()
capability
=
90
#
capability[0] * 10 + capability[1]
def
to_fp8
(
tensor
:
torch
.
Tensor
):
...
...
@@ -39,7 +39,7 @@ def baseline_scaled_mm(a: torch.Tensor,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
],
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
output
=
(
scale_a
*
(
scale_b
*
(
torch
.
mm
(
output
=
(
scale_a
*
(
scale_b
.
T
*
(
torch
.
mm
(
a
.
to
(
dtype
=
torch
.
float32
),
b
.
to
(
dtype
=
torch
.
float32
))))).
to
(
out_dtype
)
if
bias
is
not
None
:
output
=
output
+
bias
...
...
@@ -99,7 +99,7 @@ def cutlass_int8_gemm_helper(m: int,
scale_a
=
(
torch
.
randn
((
m_a_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
scale_b
=
(
torch
.
randn
((
1
,
n_b_scales
),
device
=
device
,
scale_b
=
(
torch
.
randn
((
n_b_scales
,
1
),
device
=
device
,
dtype
=
torch
.
float32
))
if
use_bias
:
...
...
@@ -107,42 +107,48 @@ def cutlass_int8_gemm_helper(m: int,
else
:
bias
=
None
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
# print("a.shape:",a.shape)
# print("b.shape:",b.shape)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
baseline
=
baseline_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
# print("out:",out[0:5][0:5])
# print("baseline:",baseline[0:5][0:5])
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm
,
(
out
,
a
,
b
,
scale_a
,
scale_b
,
bias
))
#
opcheck(torch.ops._C.cutlass_scaled_mm,
#
(out, a, b, scale_a, scale_b, bias))
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
100
,
33
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
4096
,
8192
,
16384
,
24576
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
#
@pytest.mark.parametrize("m", [1, 16, 32, 64, 128, 256, 512, 222, 100, 33])
#
@pytest.mark.parametrize("n", [2048, 4096, 8192, 16384, 24576, 256, 1024])
#
@pytest.mark.parametrize("k", [128, 496, 1024])
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm(m: int, n: int, k: int, per_act_token: bool,
#
per_out_ch: bool, use_bias: bool):
#
cutlass_fp8_gemm_helper(m, n, k, per_act_token, per_out_ch, use_bias)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
1
,
16
,
32
,
64
,
128
,
256
,
512
,
222
,
33
,
1
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
2048
,
8192
,
16384
,
256
,
1024
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
128
,
496
,
1024
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm
(
m
:
int
,
n
:
int
,
k
:
int
,
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
cutlass_int8_gemm_helper
(
m
,
n
,
k
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
b
float16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
float16
])
#
torch.
b
float16
,
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
...
...
@@ -156,50 +162,50 @@ def test_cutlass_int8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_output_dtype
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
out_dtype
:
Type
[
torch
.
dtype
],
use_bias
:
bool
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
,
device
:
str
):
cutlass_fp8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
torch
.
bfloat16
,
device
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"device"
,
CUDA_DEVICES
)
def
test_cutlass_int8_gemm_devices
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
,
device
:
str
):
cutlass_int8_gemm_helper
(
512
,
512
,
512
,
per_act_token
,
per_out_ch
,
use_bias
,
out_dtype
=
torch
.
bfloat16
,
device
=
device
)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm_output_dtype(per_act_token: bool, per_out_ch: bool,
#
out_dtype: Type[torch.dtype],
#
use_bias: bool):
#
cutlass_fp8_gemm_helper(512,
#
512,
#
512,
#
per_act_token,
#
per_out_ch,
#
use_bias,
#
out_dtype=out_dtype)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.parametrize("device", CUDA_DEVICES)
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm_devices(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool, device: str):
#
cutlass_fp8_gemm_helper(512, 512, 512, per_act_token, per_out_ch, use_bias,
#
torch.bfloat16, device)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.parametrize("device", CUDA_DEVICES)
#
def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool, device: str):
#
cutlass_int8_gemm_helper(512,
#
512,
#
512,
#
per_act_token,
#
per_out_ch,
#
use_bias,
#
out_dtype=torch.bfloat16,
#
device=device)
# For the following two tests:
...
...
@@ -207,155 +213,155 @@ def test_cutlass_int8_gemm_devices(per_act_token: bool, per_out_ch: bool,
# of a large power of two. In any case, the kernel will have a naive fallback
# when N and K are not divisible by 16. But M is the number of tokens and the
# kernel must handle any M thrown at it.
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
skipif
(
not
current_platform
.
has_device_capability
(
89
),
reason
=
"FP8 is not supported on this GPU type."
)
def
test_cutlass_fp8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_fp8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
def
test_cutlass_int8_gemm_m_sweep
(
per_act_token
:
bool
,
per_out_ch
:
bool
,
use_bias
:
bool
):
for
nk
in
range
(
32
,
128
,
32
):
for
m
in
range
(
1
,
128
):
cutlass_int8_gemm_helper
(
m
,
nk
,
nk
,
per_act_token
,
per_out_ch
,
use_bias
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
skip
def
test_cutlass_int8_azp_bias_fold
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
):
# Currently, the test is failing because folding azp into
# 16-bit bias loses too much precision
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
aq_i8
=
rand_int8
((
m
,
k
))
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
b_dq
=
scale_b
*
bq_f32
azp_a
=
torch
.
rand
((
1
,
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
+
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
+
azp_a
)
baseline_dq
=
torch
.
mm
(
a_dq
,
b_dq
).
to
(
out_dtype
)
J
=
torch
.
ones
((
1
,
k
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
azp_bias
=
(
azp_a
*
scale_b
*
(
J
@
bq_f32
)).
to
(
out_dtype
)
assert
azp_bias
.
shape
==
(
1
,
n
)
assert
azp_bias
[
0
,
:].
shape
==
(
n
,
)
baseline_q
=
(
scale_a
.
to
(
device
=
'cpu'
)
*
scale_b
.
to
(
device
=
'cpu'
)
*
(
(
aq_i32
+
azp_aq_i8
).
to
(
device
=
'cpu'
)
@
bq_i32
.
to
(
device
=
'cpu'
))).
to
(
dtype
=
out_dtype
,
device
=
'cuda'
)
out
=
ops
.
cutlass_scaled_mm
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
=
out_dtype
,
bias
=
azp_bias
[
0
,
:])
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
1e-2
,
atol
=
1e0
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
1e-2
,
atol
=
1e0
)
@
pytest
.
mark
.
parametrize
(
"m"
,
[
32
,
64
,
128
])
@
pytest
.
mark
.
parametrize
(
"n"
,
[
16
,
32
,
64
])
@
pytest
.
mark
.
parametrize
(
"k"
,
[
64
,
128
,
256
])
@
pytest
.
mark
.
parametrize
(
"out_dtype"
,
[
torch
.
bfloat16
,
torch
.
float16
])
@
pytest
.
mark
.
parametrize
(
"use_bias"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"azp_per_token"
,
[
True
,
False
])
def
test_cutlass_int8_azp
(
m
:
int
,
n
:
int
,
k
:
int
,
out_dtype
:
torch
.
dtype
,
use_bias
:
bool
,
azp_per_token
:
bool
):
m_azp
=
m
if
azp_per_token
else
1
scale_a
=
torch
.
randn
((
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
n
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
aq_i8
=
rand_int8
((
m
,
k
))
aq_i32
=
aq_i8
.
to
(
dtype
=
torch
.
int32
)
aq_f32
=
aq_i8
.
to
(
dtype
=
torch
.
float32
)
bq_i8
=
rand_int8
((
n
,
k
)).
t
()
bq_i32
=
bq_i8
.
to
(
dtype
=
torch
.
int32
)
bq_f32
=
bq_i8
.
to
(
dtype
=
torch
.
float32
)
b_dq
=
scale_b
*
bq_f32
azp_a
=
torch
.
rand
(
(
m_azp
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
*
10
+
1.5
azp_aq_i8
=
(
azp_a
/
scale_a
).
to
(
dtype
=
torch
.
int8
)
azp_a
=
azp_aq_i8
.
to
(
dtype
=
torch
.
float32
)
*
scale_a
# correct for rounding
a_dq
=
scale_a
*
(
aq_i32
-
azp_aq_i8
).
to
(
dtype
=
torch
.
float32
)
torch
.
testing
.
assert_close
(
a_dq
,
scale_a
*
aq_f32
-
azp_a
,
rtol
=
1e-4
,
atol
=
1e-3
)
if
use_bias
:
bias
=
torch
.
rand
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
*
10
+
2.5
else
:
bias
=
torch
.
zeros
((
1
,
n
),
device
=
"cuda"
,
dtype
=
out_dtype
)
baseline_dq
=
(
torch
.
mm
(
a_dq
,
b_dq
)
+
bias
).
to
(
out_dtype
)
# int32 mm not supported on CUDA
a_noazp_i32_cpu
=
(
aq_i32
-
azp_aq_i8
).
to
(
device
=
'cpu'
)
cq
=
(
a_noazp_i32_cpu
@
bq_i32
.
to
(
device
=
'cpu'
)).
to
(
device
=
'cuda'
)
baseline_q
=
(
scale_a
*
scale_b
*
cq
+
bias
).
to
(
dtype
=
out_dtype
)
# Hadamard is just the sum of the cols
azp_adj_i32
=
bq_i32
.
sum
(
dim
=
0
,
keepdim
=
True
,
dtype
=
torch
.
int32
)
azp_i32
=
azp_aq_i8
.
to
(
dtype
=
torch
.
int32
)
func_bias
=
bias
if
use_bias
else
None
if
azp_per_token
:
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
,
azp_adj_i32
,
azp_i32
,
func_bias
)
else
:
azp_with_adj_i32
=
azp_i32
*
azp_adj_i32
out
=
ops
.
cutlass_scaled_mm_azp
(
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
out_dtype
,
azp_with_adj_i32
,
None
,
func_bias
)
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
rtol
=
1e-2
if
out_dtype
==
torch
.
bfloat16
else
1e-3
atol
=
1e-3
torch
.
testing
.
assert_close
(
out
,
baseline_dq
,
rtol
=
rtol
,
atol
=
atol
)
torch
.
testing
.
assert_close
(
out
,
baseline_q
,
rtol
=
rtol
,
atol
=
atol
)
if
azp_per_token
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_adj_i32
,
azp_i32
,
func_bias
))
else
:
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_azp
,
(
out
,
aq_i8
,
bq_i8
,
scale_a
,
scale_b
,
azp_with_adj_i32
,
None
,
func_bias
))
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.skipif(not current_platform.has_device_capability(89),
#
reason="FP8 is not supported on this GPU type.")
#
def test_cutlass_fp8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool):
#
for nk in range(32, 128, 32):
#
for m in range(1, 128):
#
cutlass_fp8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
#
use_bias)
#
@pytest.mark.parametrize("per_act_token", [True, False])
#
@pytest.mark.parametrize("per_out_ch", [True, False])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
def test_cutlass_int8_gemm_m_sweep(per_act_token: bool, per_out_ch: bool,
#
use_bias: bool):
#
for nk in range(32, 128, 32):
#
for m in range(1, 128):
#
cutlass_int8_gemm_helper(m, nk, nk, per_act_token, per_out_ch,
#
use_bias)
#
@pytest.mark.parametrize("m", [32, 64, 128])
#
@pytest.mark.parametrize("n", [16, 32, 64])
#
@pytest.mark.parametrize("k", [64, 128, 256])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
#
@pytest.mark.skip
#
def test_cutlass_int8_azp_bias_fold(m: int, n: int, k: int,
#
out_dtype: torch.dtype):
#
# Currently, the test is failing because folding azp into
#
# 16-bit bias loses too much precision
#
scale_a = torch.randn((1, 1), device="cuda", dtype=torch.float32) / 10
#
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
#
aq_i8 = rand_int8((m, k))
#
bq_i8 = rand_int8((n, k)).t()
#
aq_i32 = aq_i8.to(dtype=torch.int32)
#
bq_i32 = bq_i8.to(dtype=torch.int32)
#
aq_f32 = aq_i8.to(dtype=torch.float32)
#
bq_f32 = bq_i8.to(dtype=torch.float32)
#
b_dq = scale_b * bq_f32
#
azp_a = torch.rand((1, ), device="cuda", dtype=torch.float32) * 10 + 1.5
#
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
#
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
#
a_dq = scale_a * (aq_i32 + azp_aq_i8).to(dtype=torch.float32)
#
torch.testing.assert_close(a_dq, scale_a * aq_f32 + azp_a)
#
baseline_dq = torch.mm(a_dq, b_dq).to(out_dtype)
#
J = torch.ones((1, k), device="cuda", dtype=torch.float32)
#
azp_bias = (azp_a * scale_b * (J @ bq_f32)).to(out_dtype)
#
assert azp_bias.shape == (1, n)
#
assert azp_bias[0, :].shape == (n, )
#
baseline_q = (scale_a.to(device='cpu') * scale_b.to(device='cpu') * (
#
(aq_i32 + azp_aq_i8).to(device='cpu') @ bq_i32.to(device='cpu'))).to(
#
dtype=out_dtype, device='cuda')
#
out = ops.cutlass_scaled_mm(aq_i8,
#
bq_i8,
#
scale_a,
#
scale_b,
#
out_dtype=out_dtype,
#
bias=azp_bias[0, :])
#
torch.testing.assert_close(out, baseline_dq, rtol=1e-2, atol=1e0)
#
torch.testing.assert_close(out, baseline_q, rtol=1e-2, atol=1e0)
#
@pytest.mark.parametrize("m", [32, 64, 128])
#
@pytest.mark.parametrize("n", [16, 32, 64])
#
@pytest.mark.parametrize("k", [64, 128, 256])
#
@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float16])
#
@pytest.mark.parametrize("use_bias", [True, False])
#
@pytest.mark.parametrize("azp_per_token", [True, False])
#
def test_cutlass_int8_azp(m: int, n: int, k: int, out_dtype: torch.dtype,
#
use_bias: bool, azp_per_token: bool):
#
m_azp = m if azp_per_token else 1
#
scale_a = torch.randn((m_azp, 1), device="cuda", dtype=torch.float32) / 10
#
scale_b = torch.randn((1, n), device="cuda", dtype=torch.float32) / 10
#
aq_i8 = rand_int8((m, k))
#
aq_i32 = aq_i8.to(dtype=torch.int32)
#
aq_f32 = aq_i8.to(dtype=torch.float32)
#
bq_i8 = rand_int8((n, k)).t()
#
bq_i32 = bq_i8.to(dtype=torch.int32)
#
bq_f32 = bq_i8.to(dtype=torch.float32)
#
b_dq = scale_b * bq_f32
#
azp_a = torch.rand(
#
(m_azp, 1), device="cuda", dtype=torch.float32) * 10 + 1.5
#
azp_aq_i8 = (azp_a / scale_a).to(dtype=torch.int8)
#
azp_a = azp_aq_i8.to(dtype=torch.float32) * scale_a # correct for rounding
#
a_dq = scale_a * (aq_i32 - azp_aq_i8).to(dtype=torch.float32)
#
torch.testing.assert_close(a_dq,
#
scale_a * aq_f32 - azp_a,
#
rtol=1e-4,
#
atol=1e-3)
#
if use_bias:
#
bias = torch.rand((1, n), device="cuda", dtype=out_dtype) * 10 + 2.5
#
else:
#
bias = torch.zeros((1, n), device="cuda", dtype=out_dtype)
#
baseline_dq = (torch.mm(a_dq, b_dq) + bias).to(out_dtype)
#
# int32 mm not supported on CUDA
#
a_noazp_i32_cpu = (aq_i32 - azp_aq_i8).to(device='cpu')
#
cq = (a_noazp_i32_cpu @ bq_i32.to(device='cpu')).to(device='cuda')
#
baseline_q = (scale_a * scale_b * cq + bias).to(dtype=out_dtype)
#
# Hadamard is just the sum of the cols
#
azp_adj_i32 = bq_i32.sum(dim=0, keepdim=True, dtype=torch.int32)
#
azp_i32 = azp_aq_i8.to(dtype=torch.int32)
#
func_bias = bias if use_bias else None
#
if azp_per_token:
#
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
#
out_dtype, azp_adj_i32, azp_i32,
#
func_bias)
#
else:
#
azp_with_adj_i32 = azp_i32 * azp_adj_i32
#
out = ops.cutlass_scaled_mm_azp(aq_i8, bq_i8, scale_a, scale_b,
#
out_dtype, azp_with_adj_i32, None,
#
func_bias)
#
# bfloat16 precision is 7-bit mantissa -> 2^-8 ~ 0.4%
#
# float16 precision is 10-bit mantissa -> 2^-11 ~ 0.05%
#
rtol = 1e-2 if out_dtype == torch.bfloat16 else 1e-3
#
atol = 1e-3
#
torch.testing.assert_close(out, baseline_dq, rtol=rtol, atol=atol)
#
torch.testing.assert_close(out, baseline_q, rtol=rtol, atol=atol)
#
if azp_per_token:
#
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
#
(out, aq_i8, bq_i8, scale_a, scale_b, azp_adj_i32, azp_i32,
#
func_bias))
#
else:
#
opcheck(torch.ops._C.cutlass_scaled_mm_azp,
#
(out, aq_i8, bq_i8, scale_a, scale_b, azp_with_adj_i32, None,
#
func_bias))
# Test working with a subset of A and B
...
...
@@ -367,7 +373,11 @@ def test_cutlass_subset():
whole_b
=
to_int8
(
torch
.
randn
((
big_n
,
big_k
),
device
=
"cuda"
).
t
()
*
5
)
a
=
whole_a
[
0
:
m
,
0
:
k
]
b
=
whole_b
[
0
:
k
,
0
:
n
]
#变成连续内存,矩阵子模块目前不支持计算,需要重新计算lda
a
=
a
.
contiguous
().
reshape
(
m
,
-
1
)
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
scale_a
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
scale_b
=
torch
.
randn
((
1
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
...
...
@@ -399,25 +409,26 @@ class CutlassLayer(torch.nn.Module):
return
ops
.
cutlass_scaled_mm
(
a
,
self
.
b
,
self
.
scale_a
,
self
.
scale_b
,
self
.
out_dtype
)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
,
False
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
,
False
])
#目前只支持per-act-token+per-out-ch(fp16)
@
pytest
.
mark
.
parametrize
(
"per_act_token"
,
[
True
])
@
pytest
.
mark
.
parametrize
(
"per_out_ch"
,
[
True
])
def
test_cutlass_cuda_graph
(
per_act_token
:
bool
,
per_out_ch
:
bool
):
m
,
n
,
k
=
512
,
512
,
512
a
=
to_int8
(
torch
.
randn
((
m
,
k
),
device
=
"cuda"
))
b
=
to_int8
(
torch
.
randn
((
n
,
k
),
device
=
"cuda"
).
t
())
b
=
b
.
contiguous
().
reshape
(
k
,
-
1
)
m_a_scales
=
m
if
per_act_token
else
1
n_b_scales
=
n
if
per_out_ch
else
1
scale_a
=
(
torch
.
randn
(
(
m_a_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
scale_b
=
(
torch
.
randn
(
(
1
,
n_b_scales
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
(
n_b_scales
,
1
),
device
=
"cuda"
,
dtype
=
torch
.
float32
)
/
10
)
# Construct a trivial model with a single layer that calls a CUTLASS kernel
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
b
float16
)
model
=
CutlassLayer
(
b
,
scale_a
,
scale_b
,
torch
.
float16
)
# Run the model with a cuda graph
stream
=
torch
.
cuda
.
Stream
()
...
...
@@ -429,9 +440,9 @@ def test_cutlass_cuda_graph(per_act_token: bool, per_out_ch: bool):
g
.
replay
()
baseline
=
torch
.
mm
(
scale_a
*
a
.
to
(
dtype
=
torch
.
float32
),
scale_b
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
bfloat16
)
scale_b
.
T
*
b
.
to
(
dtype
=
torch
.
float32
)).
to
(
torch
.
float16
)
#print("baseline:",baseline)
out
=
ops
.
cutlass_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
torch
.
float16
)
#print("out:",out)
torch
.
testing
.
assert_close
(
out
,
baseline
,
rtol
=
1e-1
,
atol
=
1e0
)
def
test_cutlass_support_opcheck
():
opcheck
(
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
,
(
capability
,
))
vllm/_custom_ops.py
View file @
2d5a25cd
...
...
@@ -12,7 +12,7 @@ from vllm.platforms import current_platform
try
:
from
lmslim
import
quant_ops
except
Exception
:
print
(
"INFO: Please install lmslim if you want to infer gptq or awq model.
\n
"
)
print
(
"INFO: Please install lmslim if you want to infer gptq or awq
or w8a8
model.
\n
"
)
logger
=
init_logger
(
__name__
)
...
...
@@ -706,9 +706,9 @@ def cutlass_scaled_mm(a: torch.Tensor,
# torch.ops._C.cutlass_scaled_mm(out, a, b, scale_a, scale_b, bias)
# return out
#return quant_ops.cutlass_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias)
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
2d5a25cd
...
...
@@ -4,12 +4,12 @@ import torch
from
vllm
import
_custom_ops
as
ops
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
# Input scaling factors are no longer optional in _scaled_mm starting
# from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
TORCH_DEVICE_IDENTITY
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
W8A8_TRITONJSON
=
W8a8GetCacheJSON
()
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
...
...
@@ -200,12 +200,37 @@ def apply_int8_linear(
x_q
,
x_scale
,
_
=
ops
.
scaled_int8_quant
(
input
,
input_scale
)
if
w8a8_strategy
==
1
:
m
=
x_q
.
shape
[
0
]
k
=
x_q
.
shape
[
1
]
n
=
weight
.
shape
[
1
]
if
f
"
{
m
}
_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m
}
_
{
n
}
_
{
k
}
"
]
#print("json files:",best_config)
elif
f
"1_
{
n
}
_
{
k
}
"
in
W8A8_TRITONJSON
.
triton_json_dict
[
0
]:
if
m
<
64
:
m_
=
32
elif
m
<
128
:
m_
=
64
elif
m
<
256
:
m_
=
128
elif
m
<
512
:
m_
=
256
elif
m
<
1024
:
m_
=
512
else
:
m_
=
1024
best_config
=
W8A8_TRITONJSON
.
triton_json_dict
[
0
][
f
"
{
m_
}
_
{
n
}
_
{
k
}
"
]
else
:
best_config
=
None
print
(
"config not found!"
)
return
ops
.
triton_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
bias
=
bias
,
best_config
=
best_config
)
elif
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
...
...
vllm/model_executor/models/llama.py
View file @
2d5a25cd
...
...
@@ -51,7 +51,7 @@ from vllm.model_executor.model_loader.weight_utils import (
default_weight_loader
,
kv_cache_scales_loader
,
maybe_remap_kv_scale_name
)
from
vllm.model_executor.sampling_metadata
import
SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
from
vllm.utils
import
is_hip
from
vllm.utils
import
is_hip
,
W8a8GetCacheJSON
from
.interfaces
import
SupportsLoRA
from
.utils
import
PPMissingLayer
,
is_pp_missing_parameter
,
make_layers
...
...
@@ -424,6 +424,8 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
quant_config
,
lora_config
=
lora_config
,
prefix
=
"model"
)
self
.
tritonsingleton
=
W8a8GetCacheJSON
()
if
get_pp_group
().
is_last_rank
:
self
.
unpadded_vocab_size
=
config
.
vocab_size
if
lora_config
:
...
...
@@ -459,6 +461,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
forward
(
self
,
...
...
@@ -648,6 +651,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
qweight_pad
=
torch
.
zeros
(
dim_n
,
int
(
group_size
//
4
),
dtype
=
torch
.
int32
).
cuda
()
qweight
.
data
=
torch
.
cat
((
qweight
.
data
,
qweight_pad
),
dim
=
1
).
contiguous
()
#当为triton支持推理的时候不能进行处理
if
self
.
quant_method
==
"compressed_tensors"
:
lay_key_words
=
[
"self_attn.qkv_proj.weight"
,
...
...
@@ -656,14 +660,38 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
"mlp.down_proj.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
n
=
weight_data
.
shape
[
0
]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
...
...
vllm/model_executor/models/qwen.py
View file @
2d5a25cd
...
...
@@ -48,7 +48,7 @@ from vllm.multimodal import MULTIMODAL_REGISTRY
from
vllm.multimodal.base
import
MultiModalInputs
from
vllm.multimodal.utils
import
cached_get_tokenizer
from
vllm.sequence
import
IntermediateTensors
,
SequenceData
from
vllm.utils
import
is_list_of
from
vllm.utils
import
is_list_of
,
W8a8GetCacheJSON
from
.utils
import
flatten_bn
,
is_pp_missing_parameter
,
make_layers
from
vllm
import
_custom_ops
as
ops
...
...
@@ -904,6 +904,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self
.
use_gemm_pad
=
os
.
environ
.
get
(
'GEMM_PAD'
)
==
'1'
self
.
use_fa_pad
=
os
.
environ
.
get
(
'FA_PAD'
)
==
'1'
self
.
use_awq_pad
=
os
.
environ
.
get
(
'AWQ_PAD'
)
==
'1'
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
def
_get_image_input_type
(
self
,
...
...
@@ -1100,11 +1101,35 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
"mlp.c_proj.weight"
,
]
combined_words
=
"|"
.
join
(
lay_key_words
)
weight_shapes
=
[]
all_json
=
{}
for
layername
,
weight
in
params_dict
.
items
():
matches
=
re
.
findall
(
combined_words
,
layername
)
if
matches
:
if
matches
and
"scale"
not
in
layername
:
weight_data
=
params_dict
[
layername
]
k
=
weight_data
.
shape
[
0
]
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
k
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
n
=
weight_data
.
shape
[
0
]
#rocblas和cutlass目前都需要weight做处理,但是triton不用
if
self
.
w8a8_strategy
!=
1
:
_weight
=
weight_data
.
T
.
contiguous
().
reshape
(
n
,
-
1
)
weight_data
.
data
.
copy_
(
_weight
)
#下面是针对模型记录模型出现k和n值
elif
len
(
weight_shapes
)
<
4
:
k
=
weight_data
.
shape
[
1
]
weight_shapes
.
append
({
n
,
k
})
json_file
=
self
.
tritonsingleton
.
get_w8a8json_name
(
n
,
k
)
configs_dict
=
self
.
tritonsingleton
.
get_triton_cache
(
json_file
,
n
,
k
)
if
configs_dict
:
all_json
.
update
(
configs_dict
)
if
self
.
w8a8_strategy
==
1
:
self
.
tritonsingleton
.
triton_json_dict
.
append
(
all_json
)
#找到的所有config都进行一次warmup
for
key
,
value
in
all_json
.
items
():
m
=
int
(
key
.
split
(
'_'
)[
0
])
n
=
int
(
key
.
split
(
'_'
)[
1
])
k
=
int
(
key
.
split
(
'_'
)[
2
])
ops
.
_int8_gemm_helper
(
m
=
m
,
n
=
n
,
k
=
k
,
per_token_act_quant
=
True
,
per_out_channel_weight_quant
=
True
,
use_bias
=
False
,
best_config
=
value
)
vllm/utils.py
View file @
2d5a25cd
...
...
@@ -16,6 +16,7 @@ import threading
import
uuid
import
warnings
import
weakref
import
json
from
asyncio
import
FIRST_COMPLETED
,
ensure_future
from
functools
import
lru_cache
,
partial
,
wraps
from
platform
import
uname
...
...
@@ -1334,3 +1335,88 @@ class AtomicCounter:
@
property
def
value
(
self
):
return
self
.
_value
class
W8a8GetCacheJSON
:
_instance
=
None
def
__new__
(
cls
,
*
args
,
**
kwargs
):
if
cls
.
_instance
is
None
:
cls
.
_instance
=
super
(
W8a8GetCacheJSON
,
cls
).
__new__
(
cls
,
*
args
,
**
kwargs
)
cls
.
_instance
.
_initialize
()
return
cls
.
_instance
def
_initialize
(
self
):
self
.
triton_json_dir
=
(
os
.
getenv
(
'TRITON_JSON_DIR'
,
'./cache'
))
self
.
triton_json_dict
=
[]
def
getspec_config
(
self
,
configs_dict
,
M
,
N
,
K
):
if
f
"
{
M
}
_
{
N
}
_
{
K
}
"
in
configs_dict
:
return
configs_dict
[
f
"
{
M
}
_
{
N
}
_
{
K
}
"
]
else
:
return
None
def
get_triton_cache_tune
(
self
,
file_path
,
n
,
k
):
#tuning的时候使用,当文件不存在时候,则创建文件夹
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
#try:
with
open
(
cache_json_file
,
'r'
)
as
file
:
cachedata
=
json
.
load
(
file
)
else
:
folder_path
=
os
.
path
.
dirname
(
file_path
)
os
.
makedirs
(
folder_path
,
exist_ok
=
True
)
cachedata
=
{}
# 写入空数据到新的JSON文件
with
open
(
file_path
,
'w'
)
as
file
:
json
.
dump
(
cachedata
,
file
)
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict
=
{}
for
key
,
value
in
cachedata
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_value
=
{
'SPLIT_K'
:
int
(
sub_value
[
"SPLIT_K"
]),
'BLOCK_SIZE_M'
:
int
(
sub_value
[
"BLOCK_SIZE_M"
]),
'BLOCK_SIZE_N'
:
int
(
sub_value
[
"BLOCK_SIZE_N"
]),
'BLOCK_SIZE_K'
:
int
(
sub_value
[
"BLOCK_SIZE_K"
]),
'GROUP_SIZE_M'
:
int
(
sub_value
[
"GROUP_SIZE_M"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
])
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
def
get_triton_cache
(
self
,
file_path
,
n
,
k
):
#在非tuning的时候使用,当文件不存在则直接返回none
cache_json_file
=
file_path
if
os
.
path
.
exists
(
file_path
):
#try:
with
open
(
cache_json_file
,
'r'
)
as
file
:
cachedata
=
json
.
load
(
file
)
else
:
return
None
#把所有的cache解析成key:config的形式:[M_N_K]:[config]
configs_dict
=
{}
for
key
,
value
in
cachedata
.
items
():
for
sub_key
,
sub_value
in
value
.
items
():
configs_key
=
f
"
{
sub_key
}
_
{
key
}
"
configs_value
=
{
'SPLIT_K'
:
int
(
sub_value
[
"SPLIT_K"
]),
'BLOCK_SIZE_M'
:
int
(
sub_value
[
"BLOCK_SIZE_M"
]),
'BLOCK_SIZE_N'
:
int
(
sub_value
[
"BLOCK_SIZE_N"
]),
'BLOCK_SIZE_K'
:
int
(
sub_value
[
"BLOCK_SIZE_K"
]),
'GROUP_SIZE_M'
:
int
(
sub_value
[
"GROUP_SIZE_M"
]),
'num_stages'
:
int
(
sub_value
[
'num_stages'
]),
'num_warps'
:
int
(
sub_value
[
'num_warps'
])
}
configs_dict
[
configs_key
]
=
configs_value
return
configs_dict
def
get_w8a8json_name
(
self
,
n
,
k
):
device_name
=
current_platform
.
get_device_name
().
replace
(
" "
,
"_"
)
return
self
.
triton_json_dir
+
f
"/W8A8_
{
n
}
_
{
k
}
_DCU
{
device_name
}
.json"
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment