xuwx1 / LightX2V / Commits / 3aa95081

Commit 3aa95081 authored Apr 09, 2025 by helloyongyang

[Feature]: support many quant kernels

parent 9a686a73

Showing 1 changed file with 280 additions and 64 deletions:

lightx2v/common/ops/mm/mm_weight.py (+280 -64)
 import torch
 from abc import ABCMeta, abstractmethod
 from vllm import _custom_ops as ops
+import sgl_kernel
 from lightx2v.utils.registry_factory import MM_WEIGHT_REGISTER
 from lightx2v.utils.quant_utils import IntegerQuantizer, FloatQuantizer
...
@@ -9,6 +10,11 @@ try:
 except ImportError:
     Q8F = None
 
+try:
+    import deep_gemm
+except ImportError:
+    deep_gemm = None
+
 
 class MMWeightTemplate(metaclass=ABCMeta):
     def __init__(self, weight_name, bias_name):
...
@@ -70,8 +76,102 @@ class MMWeightForceFP32(MMWeight):
         self.bias = self.bias.to(torch.float32)
 
 
+class MMWeightQuantTemplate(MMWeightTemplate):
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+        self.load_func = None
+        self.weight_need_transpose = True
+        self.act_quant_func = None
+
+    """
+    weight load functions
+    """
+
+    def load(self, weight_dict):
+        self.load_func(weight_dict)
+        if self.weight_need_transpose:
+            self.weight = self.weight.t()
+
+    def load_quantized(self, weight_dict):
+        self.weight = weight_dict[self.weight_name].cuda()
+        self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
+
+    def load_fp8_perchannel_sym(self, weight_dict):
+        if self.config.get("weight_auto_quant", True):
+            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
+            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
+            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
+            self.weight = self.weight.to(torch.float8_e4m3fn)
+            self.weight_scale = self.weight_scale.to(torch.float32)
+        else:
+            self.load_quantized(weight_dict)
+        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+
+    def load_int8_perchannel_sym(self, weight_dict):
+        if self.config.get("weight_auto_quant", True):
+            self.weight = weight_dict[self.weight_name].to(torch.float32)
+            w_quantizer = IntegerQuantizer(8, True, "per_channel")
+            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
+            self.weight = self.weight.to(torch.int8)
+            self.weight_scale = self.weight_scale.to(torch.float32)
+        else:
+            self.load_quantized(weight_dict)
+        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
+
+    def load_fp8_perblock128_sym(self, weight_dict):
+        if self.config.get("weight_auto_quant", True):
+            self.weight = weight_dict[self.weight_name].cuda()
+            self.weight, self.weight_scale = self.per_block_cast_to_fp8(self.weight)
+        else:
+            self.load_quantized(weight_dict)
+        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
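
# Editor's note (not part of the commit): str.rstrip strips a trailing
# character *set*, not a suffix, so the ".weight" -> ".weight_scale" key
# rewrite in load_quantized can over-strip any name whose stem ends in one of
# the characters ". w e i g h t". Illustration with hypothetical key names:
#
#   "blocks.0.proj.weight".rstrip(".weight")  # -> 'blocks.0.proj'  (as intended)
#   "blocks.0.gate.weight".rstrip(".weight")  # -> 'blocks.0.ga'    (over-stripped)
#
# A suffix-safe spelling would be:
#   prefix = name[:-len(".weight")] if name.endswith(".weight") else name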
+    def per_block_cast_to_fp8(self, x):
+        assert x.dim() == 2
+        m, n = x.shape
+        x_padded = torch.zeros((deep_gemm.ceil_div(m, 128) * 128, deep_gemm.ceil_div(n, 128) * 128), dtype=x.dtype, device=x.device)
+        x_padded[:m, :n] = x
+        x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, 128)
+        x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
+        x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
+        return x_scaled.view_as(x_padded)[:m, :n].contiguous(), (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
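
# Editor's sketch (not part of the commit): per_block_cast_to_fp8 made
# runnable without deep_gemm by inlining ceil_div; shapes follow the docstring
# example further down ((4096, 2048) weight -> (32, 16) scale grid).
import torch

def per_block_cast_to_fp8_ref(x, block=128, fp8_max=448.0):
    m, n = x.shape
    pm, pn = -(-m // block) * block, -(-n // block) * block  # ceil to block grid
    x_padded = torch.zeros(pm, pn, dtype=x.dtype, device=x.device)
    x_padded[:m, :n] = x
    v = x_padded.view(-1, block, pn // block, block)         # one 128x128 tile per (i, j)
    amax = v.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
    q = (v * (fp8_max / amax)).to(torch.float8_e4m3fn)
    return q.view_as(x_padded)[:m, :n].contiguous(), (amax / fp8_max).view(v.size(0), v.size(2))

w = torch.randn(4096, 2048, device="cuda")
w_q, w_scale = per_block_cast_to_fp8_ref(w)
print(w_q.shape, w_scale.shape)  # torch.Size([4096, 2048]) torch.Size([32, 16])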
"""
act quant kernels
"""
def
act_quant_fp8_perchannel_sym_vllm
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
=
ops
.
scaled_fp8_quant
(
x
,
None
,
scale_ub
=
None
,
use_per_token_if_dynamic
=
True
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_fp8_perchannel_sym_sgl
(
self
,
x
):
m
,
k
=
x
.
shape
input_tensor_quant
=
torch
.
empty
((
m
,
k
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
,
requires_grad
=
False
)
input_tensor_scale
=
torch
.
empty
((
m
,
1
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
requires_grad
=
False
)
sgl_kernel
.
sgl_per_token_quant_fp8
(
x
,
input_tensor_quant
,
input_tensor_scale
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_int8_perchannel_sym_vllm
(
self
,
x
):
input_tensor_quant
,
input_tensor_scale
,
_
=
ops
.
scaled_int8_quant
(
x
,
scale
=
None
,
azp
=
None
,
symmetric
=
True
)
return
input_tensor_quant
,
input_tensor_scale
def
act_quant_fp8_perchannelgroup128_sym_deepgemm
(
self
,
x
):
assert
x
.
dim
()
==
2
and
x
.
size
(
1
)
%
128
==
0
m
,
n
=
x
.
shape
x_view
=
x
.
view
(
m
,
-
1
,
128
)
x_amax
=
x_view
.
abs
().
float
().
amax
(
dim
=
2
).
view
(
m
,
-
1
).
clamp
(
1e-4
)
return
(
x_view
*
(
448.0
/
x_amax
.
unsqueeze
(
2
))).
to
(
torch
.
float8_e4m3fn
).
view
(
m
,
n
),
(
x_amax
/
448.0
).
view
(
m
,
-
1
)
def
act_quant_fp8_perchannelgroup128_sym_sgl
(
self
,
x
):
m
,
k
=
x
.
shape
input_tensor_quant
=
torch
.
empty
((
m
,
k
),
dtype
=
torch
.
float8_e4m3fn
,
device
=
"cuda"
,
requires_grad
=
False
)
input_tensor_scale
=
torch
.
empty
((
m
,
k
//
128
),
dtype
=
torch
.
float32
,
device
=
"cuda"
,
requires_grad
=
False
)
sgl_kernel
.
sgl_per_token_group_quant_fp8
(
x
,
input_tensor_quant
,
input_tensor_scale
,
group_size
=
128
,
eps
=
1e-10
,
fp8_min
=-
448.0
,
fp8_max
=
448.0
)
return
input_tensor_quant
,
input_tensor_scale
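
# Editor's sketch (not part of the commit): round-trip check for the group=128
# activation scheme above; 448.0 is the e4m3 max representable value.
import torch

x = torch.randn(4, 256, dtype=torch.float32, device="cuda")
x_view = x.view(4, -1, 128)
amax = x_view.abs().float().amax(dim=2).clamp(1e-4)           # (4, 2): one scale per group
q = (x_view * (448.0 / amax.unsqueeze(2))).to(torch.float8_e4m3fn)
scale = amax / 448.0
x_hat = q.float() * scale.unsqueeze(2)                        # dequantize
print((x_view - x_hat).abs().max())  # small, bounded by e4m3 resolution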
 @MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm")
-class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
+class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightQuantTemplate):
     """
     Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
...
@@ -83,31 +183,23 @@ class MMWeightWfp8channelAfp8channeldynamicVllm(MMWeightTemplate):
     def __init__(self, weight_name, bias_name):
         super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
 
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.float8_e4m3fn).t().cuda()
-            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
-        else:
-            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
-
     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
         device = input_tensor.device
         output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
-        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
-        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(output_tensor, input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, self.bias)
         return output_tensor
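
# Editor's usage sketch (not part of the commit; the registry lookup syntax
# and the config attachment are assumptions based on how the names read here):
#
#   cls = MM_WEIGHT_REGISTER["W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm"]
#   mm = cls("blocks.0.attn.qkv.weight", "blocks.0.attn.qkv.bias")
#   mm.config = {"weight_auto_quant": True}  # quantize a bf16 checkpoint on load
#   mm.load(state_dict)                      # fp8 weight + per-channel scales
#   y = mm.apply(x)                          # dynamic per-token act quant + fused mm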
 
 
 @MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm")
-class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
+class MMWeightWint8channelAint8channeldynamicVllm(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
...
@@ -119,31 +211,46 @@ class MMWeightWint8channelAint8channeldynamicVllm(MMWeightTemplate):
     def __init__(self, weight_name, bias_name):
         super().__init__(weight_name, bias_name)
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
 
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].to(torch.float32).cuda()
-            w_quantizer = IntegerQuantizer(8, True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.int8).t().cuda()
-            self.weight_scale = self.weight_scale.to(torch.float32).cuda()
-        else:
-            self.weight = weight_dict[self.weight_name].t().cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].cuda() if self.bias_name is not None else None
-
     def apply(self, input_tensor):
         shape = (input_tensor.shape[0], self.weight.shape[1])
         dtype = input_tensor.dtype
         device = input_tensor.device
         output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
-        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
-        torch.ops._C.cutlass_scaled_mm(output_tensor, qinput, self.weight, x_scale, self.weight_scale, self.bias)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(output_tensor, input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, self.bias)
         return output_tensor
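
# Editor's sketch (not part of the commit): reference math assumed equivalent
# to ops.scaled_int8_quant(x, scale=None, azp=None, symmetric=True).
import torch

x = torch.randn(8, 64, device="cuda")
scale = x.abs().amax(dim=1, keepdim=True).float() / 127.0     # (8, 1): one scale per token
q = torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)
x_hat = q.float() * scale                                     # dequantize
print((x.float() - x_hat).abs().max())  # at most ~scale/2 per element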
+
+
+@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
+class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightQuantTemplate):
+    """
+    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
+    Quant MM:
+        Weight: fp8 perchannel sym
+        Act: fp8 perchannel dynamic sym
+        Kernel: Q8F
+    """
+
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perchannel_sym
+        self.weight_need_transpose = False
+        self.act_quant_func = self.act_quant_fp8_perchannel_sym_vllm
+
+    def apply(self, input_tensor):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = Q8F.linear.fp8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, out_dtype=torch.bfloat16)
+        return output_tensor.squeeze(0)
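
# Editor's sketch (not part of the commit): the dequantized semantics these
# fused kernels (cutlass_scaled_mm, Q8F.linear.*, the sgl and deep_gemm mms)
# are assumed to share, written out in plain torch:
import torch

m, k, n = 16, 64, 32
x_q = torch.randint(-127, 128, (m, k), device="cuda").float()
x_s = torch.rand(m, 1, device="cuda")                  # per-token act scale
w_q = torch.randint(-127, 128, (n, k), device="cuda").float()
w_s = torch.rand(n, 1, device="cuda")                  # per-channel weight scale
bias = torch.randn(n, device="cuda")
y = ((x_q * x_s) @ (w_q * w_s).t() + bias).to(torch.bfloat16)  # (m, n)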
 
 
 @MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F")
-class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
+class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightQuantTemplate):
     """
     Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
...
@@ -155,55 +262,164 @@ class MMWeightWint8channelAint8channeldynamicQ8F(MMWeightTemplate):
     def __init__(self, weight_name, bias_name):
         super().__init__(weight_name, bias_name)
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = False
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
 
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = IntegerQuantizer(8, True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.int8)
-            self.weight_scale = self.weight_scale.to(torch.float32)
-        else:
-            self.weight = weight_dict[self.weight_name].cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
-
-    def apply(self, input_tensor, act=None):
-        qinput, x_scale, _ = ops.scaled_int8_quant(input_tensor, scale=None, azp=None, symmetric=True)
-        output_tensor = Q8F.linear.q8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
+    def apply(self, input_tensor):
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = Q8F.linear.q8_linear(input_tensor_quant, self.weight, self.bias, input_tensor_scale, self.weight_scale, fuse_gelu=False, out_dtype=torch.bfloat16)
         return output_tensor.squeeze(0)
 
 
-@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F")
-class MMWeightWfp8channelAfp8channeldynamicQ8F(MMWeightTemplate):
+@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm")
+class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemm(MMWeightQuantTemplate):
     """
-    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
+    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm
+    Quant MM:
+        Weight: fp8 perblock 128x128 sym
+        Act: fp8 perchannel-pergroup group=128 dynamic sym
+        Kernel: Deepgemm
+        Reference: https://github.com/deepseek-ai/DeepGEMM
+    Example:
+        Act(1024, 2048) x Weight(2048, 4096) = Out(1024, 4096)
+        Act         : torch.Size([1024, 2048]), torch.float8_e4m3fn
+        Act Scale   : torch.Size([1024, 16]), torch.float32
+        Weight      : torch.Size([4096, 2048]), torch.float8_e4m3fn
+        Weight Scale: torch.Size([32, 16]), torch.float32
+        Out         : torch.Size([1024, 4096]), torch.bfloat16
+    """
+
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perblock128_sym
+        self.weight_need_transpose = False
+        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_deepgemm
+
+    def apply(self, input_tensor):
+        shape = (input_tensor.shape[0], self.weight.shape[0])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((input_tensor_quant, input_tensor_scale), (self.weight, self.weight_scale), output_tensor)
+        if self.bias is not None:
+            output_tensor.add_(self.bias)
+        return output_tensor
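
# Editor's sketch (not part of the commit): assumed reference for the DeepGEMM
# call, expanding the block/group scales from the docstring example back to
# full resolution.
import torch

m, k, n = 1024, 2048, 4096
a_q = torch.randn(m, k); a_s = torch.rand(m, k // 128)            # act: per-token, group=128
w_q = torch.randn(n, k); w_s = torch.rand(n // 128, k // 128)     # weight: 128x128 blocks
a = a_q * a_s.repeat_interleave(128, dim=1)
w = w_q * w_s.repeat_interleave(128, dim=0).repeat_interleave(128, dim=1)
y_ref = a @ w.t()   # what gemm_fp8_fp8_bf16_nt((a_q, a_s), (w_q, w_s), out) approximates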
+
+
+@MM_WEIGHT_REGISTER("W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl")
+class MMWeightWfp8block128Afp8channelgroup128dynamicDeepgemmActSgl(MMWeightQuantTemplate):
+    """
+    Name: W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
+    Quant MM:
+        Weight: fp8 perblock 128x128 sym
+        Act: fp8 pertoken-pergroup group=128 dynamic sym
+        Kernel: quant-mm using Deepgemm, act dynamic quant using Sgl-kernel
+    """
+
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perblock128_sym
+        self.weight_need_transpose = False
+        self.act_quant_func = self.act_quant_fp8_perchannelgroup128_sym_sgl
+
+    def apply(self, input_tensor):
+        shape = (input_tensor.shape[0], self.weight.shape[0])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        deep_gemm.gemm_fp8_fp8_bf16_nt((input_tensor_quant, input_tensor_scale), (self.weight, self.weight_scale), output_tensor)
+        if self.bias is not None:
+            output_tensor.add_(self.bias)
+        return output_tensor
 
 
+@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl")
+class MMWeightWfp8channelAfp8channeldynamicVllmActSgl(MMWeightQuantTemplate):
+    """
+    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
     Quant MM:
         Weight: fp8 perchannel sym
         Act: fp8 perchannel dynamic sym
-        Kernel: Q8F
+        Kernel: quant-mm using vllm, act dynamic quant using Sgl-kernel
     """
 
     def __init__(self, weight_name, bias_name):
         super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
 
-    def load(self, weight_dict):
-        if self.config.get("weight_auto_quant", True):
-            self.weight = weight_dict[self.weight_name].cuda()
-            w_quantizer = FloatQuantizer("e4m3", True, "per_channel")
-            self.weight, self.weight_scale, _ = w_quantizer.real_quant_tensor(self.weight)
-            self.weight = self.weight.to(torch.float8_e4m3fn)
-            self.weight_scale = self.weight_scale.to(torch.float32)
-        else:
-            self.weight = weight_dict[self.weight_name].cuda()
-            self.weight_scale = weight_dict[self.weight_name.rstrip(".weight") + ".weight_scale"].cuda()
-        self.bias = weight_dict[self.bias_name].float().cuda() if self.bias_name is not None else None
-
+    def apply(self, input_tensor):
+        shape = (input_tensor.shape[0], self.weight.shape[1])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        torch.ops._C.cutlass_scaled_mm(output_tensor, input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, self.bias)
+        return output_tensor
 
 
+@MM_WEIGHT_REGISTER("W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl")
+class MMWeightWfp8channelAfp8channeldynamicSgl(MMWeightQuantTemplate):
+    """
+    Name: W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
+    Quant MM:
+        Weight: fp8 perchannel sym
+        Act: fp8 perchannel dynamic sym
+        Kernel: Sgl-kernel
+    """
+
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+        self.load_func = self.load_fp8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_fp8_perchannel_sym_sgl
 
     def apply(self, input_tensor):
-        qinput, x_scale = ops.scaled_fp8_quant(input_tensor, None, scale_ub=None, use_per_token_if_dynamic=True)
-        output_tensor = Q8F.linear.fp8_linear(qinput, self.weight, self.bias, x_scale, self.weight_scale, out_dtype=torch.bfloat16)
-        return output_tensor.squeeze(0)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = sgl_kernel.fp8_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, torch.bfloat16, bias=self.bias)
+        return output_tensor
+
+
+@MM_WEIGHT_REGISTER("W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm")
+class MMWeightWint8channelAint8channeldynamicActVllm(MMWeightQuantTemplate):
+    """
+    Name: W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm
+    Quant MM:
+        Weight: int8 perchannel sym
+        Act: int8 perchannel dynamic sym
+        Kernel: quant-mm using Sgl-kernel, act dynamic quant using vllm
+    """
+
+    def __init__(self, weight_name, bias_name):
+        super().__init__(weight_name, bias_name)
+        self.load_func = self.load_int8_perchannel_sym
+        self.weight_need_transpose = True
+        self.act_quant_func = self.act_quant_int8_perchannel_sym_vllm
+
+    def apply(self, input_tensor):
+        shape = (input_tensor.shape[0], self.weight.shape[1])
+        dtype = input_tensor.dtype
+        device = input_tensor.device
+        output_tensor = torch.empty(shape, dtype=dtype, device=device, requires_grad=False)
+        input_tensor_quant, input_tensor_scale = self.act_quant_func(input_tensor)
+        output_tensor = sgl_kernel.int8_scaled_mm(input_tensor_quant, self.weight, input_tensor_scale, self.weight_scale, torch.bfloat16, self.bias)
+        return output_tensor
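
# Editor's summary (not part of the commit): after this change the file
# registers the following quant-mm implementations, keyed as
# W-<weight scheme>-A-<activation scheme>-<kernel>:
#
#   W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm
#   W-int8-channel-sym-A-int8-channel-sym-dynamic-Vllm
#   W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Q8F
#   W-int8-channel-sym-A-int8-channel-sym-dynamic-Q8F
#   W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm
#   W-fp8-block128-sym-A-fp8-channel-group128-sym-dynamic-Deepgemm-ActSgl
#   W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Vllm-ActSgl
#   W-fp8-channel-sym-A-fp8-channel-sym-dynamic-Sgl
#   W-int8-channel-sym-A-int8-channel-sym-dynamic-Sgl-ActVllm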
 
 
 if __name__ == "__main__":
...