Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
2009d4a1
Commit
2009d4a1
authored
Oct 22, 2024
by
zhuwenwen
Browse files
增加w8a8的triton环境变量控制
parent
f7512877
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
7 deletions
+46
-7
vllm/_custom_ops.py
vllm/_custom_ops.py
+19
-0
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
...ompressed_tensors/schemes/compressed_tensors_w8a8_int8.py
+5
-1
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+22
-6
No files found.
vllm/_custom_ops.py
View file @
2009d4a1
...
@@ -715,6 +715,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
...
@@ -715,6 +715,25 @@ def cutlass_scaled_mm(a: torch.Tensor,
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
rocblas_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
quant_ops
.
rocblas_scaled_mm_nn
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
triton_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
torch
.
dtype
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
quant_ops
.
triton_scaled_mm
(
a
,
b
,
scale_a
,
scale_b
,
out_dtype
,
bias
)
def
cutlass_scaled_mm_azp
(
a
:
torch
.
Tensor
,
def
cutlass_scaled_mm_azp
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py
View file @
2009d4a1
...
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional
...
@@ -3,6 +3,7 @@ from typing import Callable, List, Optional
import
torch
import
torch
from
compressed_tensors.quantization
import
QuantizationStrategy
from
compressed_tensors.quantization
import
QuantizationStrategy
from
torch.nn
import
Parameter
from
torch.nn
import
Parameter
import
os
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
from
vllm.model_executor.layers.quantization.compressed_tensors.schemes
import
(
...
@@ -24,6 +25,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -24,6 +25,7 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
self
.
strategy
=
strategy
self
.
strategy
=
strategy
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
is_static_input_scheme
=
is_static_input_scheme
self
.
input_symmetric
=
input_symmetric
self
.
input_symmetric
=
input_symmetric
self
.
w8a8_strategy
=
int
(
os
.
getenv
(
'W8A8_SUPPORT_METHODS'
,
'0'
))
@
classmethod
@
classmethod
def
get_min_capability
(
cls
)
->
int
:
def
get_min_capability
(
cls
)
->
int
:
...
@@ -145,4 +147,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
...
@@ -145,4 +147,6 @@ class CompressedTensorsW8A8Int8(CompressedTensorsScheme):
input_scale
=
layer
.
input_scale
,
input_scale
=
layer
.
input_scale
,
input_zero_point
=
layer
.
input_zero_point
,
input_zero_point
=
layer
.
input_zero_point
,
azp_adj
=
layer
.
azp_adj
,
azp_adj
=
layer
.
azp_adj
,
bias
=
bias
)
bias
=
bias
,
w8a8_strategy
=
self
.
w8a8_strategy
)
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
2009d4a1
...
@@ -195,6 +195,7 @@ def apply_int8_linear(
...
@@ -195,6 +195,7 @@ def apply_int8_linear(
input_zero_point
:
Optional
[
torch
.
Tensor
]
=
None
,
input_zero_point
:
Optional
[
torch
.
Tensor
]
=
None
,
azp_adj
:
Optional
[
torch
.
Tensor
]
=
None
,
azp_adj
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
w8a8_strategy
:
Optional
[
int
]
=
0
,
):
):
# ops.scaled_int8_quant supports both dynamic and static quant.
# ops.scaled_int8_quant supports both dynamic and static quant.
# * dynamic, layer.input_scale is None and x_scale computed from x.
# * dynamic, layer.input_scale is None and x_scale computed from x.
...
@@ -214,12 +215,27 @@ def apply_int8_linear(
...
@@ -214,12 +215,27 @@ def apply_int8_linear(
azp_adj
=
azp_adj
,
azp_adj
=
azp_adj
,
azp
=
x_zp
,
azp
=
x_zp
,
bias
=
bias
)
bias
=
bias
)
return
ops
.
cutlass_scaled_mm
(
x_q
,
if
w8a8_strategy
==
1
:
weight
,
return
ops
.
triton_scaled_mm
(
x_q
,
scale_a
=
x_scale
,
weight
,
scale_b
=
weight_scale
,
scale_a
=
x_scale
,
out_dtype
=
input
.
dtype
,
scale_b
=
weight_scale
,
bias
=
bias
)
out_dtype
=
input
.
dtype
,
bias
=
bias
)
elif
w8a8_strategy
==
2
:
return
ops
.
cutlass_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
else
:
return
ops
.
rocblas_scaled_mm
(
x_q
,
weight
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
def
normalize_e4m3fn_to_e4m3fnuz
(
def
normalize_e4m3fn_to_e4m3fnuz
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment