sglang · Commits · e50109f2

Commit e50109f2 (unverified), authored Jul 21, 2025 by Hubert Lu, committed via GitHub on Jul 21, 2025

[AMD] Remove vllm's scaled_fp8_quant and moe_sum when SGLANG_USE_AITER=1 (#7484)

Parent: 69adc4f8
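
In short: on ROCm (HIP) builds, this commit routes FP8 quantization and the MoE output reduction through AITER kernels when SGLANG_USE_AITER=1, and only falls back to vLLM's custom ops otherwise, so the vLLM imports are no longer pulled in on the AITER path. The sketch below illustrates the gating idiom the diff applies across modules; get_bool_env_var and is_hip are existing helpers in sglang.srt.utils, while the branch bodies are trimmed placeholders, not the exact module code.

# Illustrative sketch (not the exact module code) of the SGLANG_USE_AITER gate used in this commit.
from sglang.srt.utils import get_bool_env_var, is_hip

_is_hip = is_hip()
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

if _use_aiter:
    # ROCm + AITER: take the AITER kernel path (e.g. aiter.moe_sum and the aiter quant ops).
    from aiter import moe_sum  # noqa: F401
else:
    # Otherwise fall back to vLLM's custom ops.
    from vllm import _custom_ops as vllm_ops  # noqa: F401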
Showing 8 changed files with 156 additions and 69 deletions.
python/sglang/srt/layers/moe/ep_moe/layer.py  (+1, -4)
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py  (+21, -5)
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py  (+3, -2)
python/sglang/srt/layers/quantization/fp8.py  (+1, -2)
python/sglang/srt/layers/quantization/fp8_kernel.py  (+115, -46)
python/sglang/srt/layers/quantization/unquant.py  (+0, -1)
python/sglang/srt/layers/quantization/utils.py  (+3, -2)
python/sglang/test/test_custom_ops.py  (+12, -7)
python/sglang/srt/layers/moe/ep_moe/layer.py

@@ -54,14 +54,11 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()

 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

-if not _is_npu:
+if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
     from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

-if _is_hip:
-    from vllm._custom_ops import scaled_fp8_quant
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
python/sglang/srt/layers/moe/fused_moe_triton/fused_moe.py

@@ -39,11 +39,20 @@ _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 elif _is_cpu and _is_cpu_amx_available:
     pass
+elif _is_hip:
+    from vllm import _custom_ops as vllm_ops  # gelu_and_mul, silu_and_mul
+
+    if _use_aiter:
+        try:
+            from aiter import moe_sum
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 else:
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant

@@ -1521,11 +1530,7 @@ def fused_experts_impl(
     routed_scaling_factor: Optional[float] = None,
 ):
     padded_size = padding_size
-    if (
-        not (use_fp8_w8a8 or use_int8_w8a8)
-        or block_shape is not None
-        or (_is_hip and get_bool_env_var("SGLANG_USE_AITER"))
-    ):
+    if not (use_fp8_w8a8 or use_int8_w8a8) or block_shape is not None or _use_aiter:
         padded_size = 0

     # Check constraints.

@@ -1723,6 +1728,17 @@ def fused_experts_impl(
                 out_hidden_states[begin_chunk_idx:end_chunk_idx],
                 routed_scaling_factor,
             )
+        elif _is_hip:
+            if _use_aiter:
+                moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
+            else:
+                vllm_ops.moe_sum(
+                    intermediate_cache3.view(*intermediate_cache3.shape),
+                    out_hidden_states[begin_chunk_idx:end_chunk_idx],
+                )
         else:
             vllm_ops.moe_sum(
                 intermediate_cache3.view(*intermediate_cache3.shape),
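
For context on the dispatch added above: moe_sum reduces the per-expert (top-k) partial outputs into one hidden state per token. A small PyTorch sketch of the assumed semantics follows; the shapes are illustrative and the real aiter / vllm kernels fuse this reduction rather than calling torch.sum.

# Assumed reference semantics of the moe_sum call in fused_experts_impl (illustrative shapes).
import torch

num_tokens, topk, hidden = 4, 2, 8
intermediate_cache3 = torch.randn(num_tokens, topk, hidden)
out_hidden_states = intermediate_cache3.sum(dim=1)  # reduce over the top-k expert dimension
assert out_hidden_states.shape == (num_tokens, hidden)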
python/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

@@ -20,7 +20,7 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput

@@ -32,8 +32,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
python/sglang/srt/layers/quantization/fp8.py

@@ -95,10 +95,9 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
python/sglang/srt/layers/quantization/fp8_kernel.py

@@ -27,6 +27,7 @@ from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
     align,
     direct_register_custom_op,
+    get_bool_env_var,
     get_device_core_count,
     get_device_name,
     is_cpu,

@@ -39,6 +40,7 @@ from sglang.srt.utils import (
 _is_hip = is_hip()
 _is_cuda = is_cuda()
 _is_cpu = is_cpu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import (

@@ -47,6 +49,22 @@ if _is_cuda:
         sgl_per_token_quant_fp8,
     )

+if _is_hip:
+    if _use_aiter:
+        try:
+            from aiter import (  # v0.1.3
+                dynamic_per_tensor_quant,
+                dynamic_per_token_scaled_quant,
+                static_per_tensor_quant,
+            )
+        except ImportError:
+            raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+    else:
+        try:
+            import vllm._C
+        except ImportError:
+            raise ImportError("vllm is required when SGLANG_USE_AITER is set to False")
+
 logger = logging.getLogger(__name__)

@@ -1116,16 +1134,10 @@ def per_token_group_quant_mla_deep_gemm_masked_fp8(
     return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m


-def scaled_fp8_quant(
-    input: torch.Tensor,
-    scale: Optional[torch.Tensor] = None,
-    num_token_padding: Optional[int] = None,
-    use_per_token_if_dynamic: bool = False,
-) -> tuple[torch.Tensor, torch.Tensor]:
-    """
-    Quantize input tensor to FP8 (8-bit floating point) format.
+"""
+Quantize input tensor to FP8 (8-bit floating point) format.

 Args:
     input (torch.Tensor): Input tensor to be quantized
     scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
         If None, scales will be computed dynamically.

@@ -1136,14 +1148,22 @@ def scaled_fp8_quant(
         - True: compute scale per token
         - False: compute single scale per tensor

 Returns:
     Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
         - quantized_tensor: The FP8 quantized version of input
         - scale_tensor: The scaling factors used for quantization

 Raises:
     AssertionError: If input is not 2D or if static scale's numel != 1
 """

+if _is_hip:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
         shape = input.shape
         if num_token_padding:

@@ -1153,7 +1173,54 @@ def scaled_fp8_quant(
         if scale is None:
             # Dynamic scaling
             if use_per_token_if_dynamic:
-                scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
+                if _use_aiter:
+                    dynamic_per_token_scaled_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_per_token_scaled_fp8_quant(
+                        output, input.contiguous(), scale, None
+                    )
+            else:
+                scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+                if _use_aiter:
+                    dynamic_per_tensor_quant(output, input, scale)
+                else:
+                    torch.ops._C.dynamic_scaled_fp8_quant(output, input, scale)
+        else:
+            # Static scaling
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
+            if _use_aiter:
+                static_per_tensor_quant(output, input, scale)
+            else:
+                torch.ops._C.static_scaled_fp8_quant(output, input, scale)
+        return output, scale
+
+else:
+
+    def scaled_fp8_quant(
+        input: torch.Tensor,
+        scale: Optional[torch.Tensor] = None,
+        num_token_padding: Optional[int] = None,
+        use_per_token_if_dynamic: bool = False,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+        shape = input.shape
+        if num_token_padding:
+            shape = (max(num_token_padding, input.shape[0]), shape[1])
+        output = torch.empty(shape, device=input.device, dtype=fp8_dtype)
+
+        if scale is None:
+            # Dynamic scaling
+            if use_per_token_if_dynamic:
+                scale = torch.empty(
+                    (shape[0], 1), device=input.device, dtype=torch.float32
+                )
                 sgl_per_token_quant_fp8(input, output, scale)
             else:
                 scale = torch.zeros(1, device=input.device, dtype=torch.float32)

@@ -1162,7 +1229,9 @@ def scaled_fp8_quant(
             )  # False for dynamic
         else:
             # Static scaling
-            assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+            assert (
+                scale.numel() == 1
+            ), f"Expected scalar scale, got numel={scale.numel()}"
             sgl_per_tensor_quant_fp8(
                 input, output, scale, is_static=True
             )  # True for static
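
A hedged usage sketch for the scaled_fp8_quant wrapper defined above, following the docstring in this file; it assumes a CUDA or ROCm device and arbitrary tensor sizes.

# Usage sketch only; requires a GPU build of sglang with the kernels above available.
import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(16, 128, dtype=torch.float16, device="cuda")

# Dynamic per-tensor scaling: a single scale for the whole tensor.
q, s = scaled_fp8_quant(x)
assert q.shape == x.shape and s.numel() == 1

# Dynamic per-token scaling: one scale per row.
q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)
assert s_tok.shape == (16, 1)

# Static scaling: reuse a precomputed scalar scale.
q_static, _ = scaled_fp8_quant(x, scale=s)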
python/sglang/srt/layers/quantization/unquant.py

@@ -37,7 +37,6 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 if _use_aiter:
     from aiter import ActivationType
     from aiter.fused_moe import fused_moe
-    from aiter.fused_moe_bf16_asm import ck_moe_2stages
     from aiter.ops.shuffle import shuffle_weight
python/sglang/srt/layers/quantization/utils.py

@@ -12,7 +12,7 @@ import torch
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
-from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_npu
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, is_cuda, is_hip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.layers.quantization.base_config import QuantizationConfig

@@ -21,8 +21,9 @@ _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()

-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm._custom_ops import scaled_fp8_quant
python/sglang/test/test_custom_ops.py

@@ -3,8 +3,13 @@
 import pytest
 import torch

-from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
-from sglang.srt.utils import is_cuda
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+_is_fp8_fnuz = is_fp8_fnuz()
+fp8_dtype = torch.float8_e4m3fnuz if _is_fp8_fnuz else torch.float8_e4m3fn

 @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])

@@ -13,10 +18,10 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     def quantize_ref_per_tensor(tensor, inv_scale):
         # The reference implementation that fully aligns to
         # the kernel being tested.
-        finfo = torch.finfo(torch.float8_e4m3fn)
+        finfo = torch.finfo(fp8_dtype)
         scale = inv_scale.reciprocal()
         qweight = (tensor.to(torch.float32) * scale).clamp(min=finfo.min, max=finfo.max)
-        qweight = qweight.to(torch.float8_e4m3fn)
+        qweight = qweight.to(fp8_dtype)
         return qweight

     def dequantize_per_tensor(tensor, inv_scale, dtype):

@@ -48,19 +53,19 @@ def test_scaled_fp8_quant_per_tensor(dtype) -> None:
     )


-if is_cuda:
+if _is_cuda or _is_hip:

     @pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
     def test_scaled_fp8_quant_per_token_dynamic(dtype) -> None:
         def quantize_ref_per_token(tensor, inv_scale):
             # The reference implementation that fully aligns to
             # the kernel being tested.
-            finfo = torch.finfo(torch.float8_e4m3fn)
+            finfo = torch.finfo(fp8_dtype)
             scale = inv_scale.reciprocal()
             qweight = (tensor.to(torch.float32) * scale).clamp(
                 min=finfo.min, max=finfo.max
             )
-            qweight = qweight.to(torch.float8_e4m3fn)
+            qweight = qweight.to(fp8_dtype)
             return qweight

         def dequantize_per_token(tensor, inv_scale, dtype):
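
These tests should run on both CUDA and ROCm builds with a plain pytest invocation over the file; a programmatic sketch is shown below (the exact CI wiring is not part of this commit).

# Run just the scaled_fp8_quant tests; equivalent to `pytest python/sglang/test/test_custom_ops.py -k scaled_fp8_quant`.
import pytest

pytest.main(["python/sglang/test/test_custom_ops.py", "-k", "scaled_fp8_quant"])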