Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e837b624
Unverified
Commit
e837b624
authored
Aug 16, 2024
by
Charlie Fu
Committed by
GitHub
Aug 16, 2024
Browse files
[Feature][Hardware][Amd] Add fp8 Linear Layer for Rocm (#7210)
parent
ec724a72
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
164 additions
and
49 deletions
+164
-49
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+32
-17
tests/kernels/quant_utils.py
tests/kernels/quant_utils.py
+22
-11
tests/kernels/test_fp8_quant.py
tests/kernels/test_fp8_quant.py
+3
-3
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-1
vllm/config.py
vllm/config.py
+1
-1
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+55
-9
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+47
-7
No files found.
csrc/quantization/fp8/common.cu
View file @
e837b624
...
...
@@ -9,6 +9,18 @@
#include "../../reduction_utils.cuh"
#ifndef USE_ROCM
using
FP8_TYPE
=
c10
::
Float8_e4m3fn
;
C10_HOST_DEVICE
constexpr
auto
FP8_E4M3_MAX
=
std
::
numeric_limits
<
FP8_TYPE
>::
max
();
#else
#include "amd/hip_float8.h"
using
FP8_TYPE
=
c10
::
Float8_e4m3fnuz
;
// Using the default max value from pytorch (240.0) will cause accuracy
// issue when running dynamic quantization. Here use 224.0f for rocm.
constexpr
auto
FP8_E4M3_MAX
=
224.0
f
;
#endif
namespace
vllm
{
__device__
__forceinline__
float
atomicMaxFloat
(
float
*
addr
,
float
value
)
{
...
...
@@ -21,11 +33,9 @@ __device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
return
old
;
}
#define FP8_E4M3_MAX std::numeric_limits<c10::Float8_e4m3fn>::max()
template
<
bool
is_scale_inverted
>
__device__
__forceinline__
c10
::
Float8_e4m3fn
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
__device__
__forceinline__
FP8_TYPE
scaled_fp8_conversion
(
float
const
val
,
float
const
scale
)
{
float
x
=
0.0
f
;
if
constexpr
(
is_scale_inverted
)
{
x
=
val
*
scale
;
...
...
@@ -34,7 +44,13 @@ __device__ __forceinline__ c10::Float8_e4m3fn scaled_fp8_conversion(
}
float
r
=
fmax
(
-
FP8_E4M3_MAX
,
fmin
(
x
,
FP8_E4M3_MAX
));
#ifndef USE_ROCM
return
static_cast
<
c10
::
Float8_e4m3fn
>
(
r
);
#else
// Use hardware cvt instruction for fp8 on rocm
return
c10
::
Float8_e4m3fnuz
(
hip_fp8
(
r
).
data
,
c10
::
Float8_e4m3fnuz
::
from_bits
());
#endif
}
// Compute the absolute maximum m of the input tensor and store
...
...
@@ -74,8 +90,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
// Finally, since cache[0] contains the maximum for this thread block,
// atomically write the max to the target location
if
(
threadIdx
.
x
==
0
)
{
atomicMaxFloat
(
scale
,
cache
[
0
]
/
std
::
numeric_limits
<
c10
::
Float8_e4m3fn
>::
max
());
atomicMaxFloat
(
scale
,
cache
[
0
]
/
FP8_E4M3_MAX
);
}
}
...
...
@@ -88,10 +103,10 @@ struct __align__(8) vec4_t {
};
typedef
struct
__align__
(
4
)
{
c10
::
Float8_e4m3fn
x
;
c10
::
Float8_e4m3fn
y
;
c10
::
Float8_e4m3fn
z
;
c10
::
Float8_e4m3fn
w
;
FP8_TYPE
x
;
FP8_TYPE
y
;
FP8_TYPE
z
;
FP8_TYPE
w
;
}
float8x4_t
;
...
...
@@ -124,7 +139,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
}
template
<
typename
scalar_t
,
bool
is_scale_inverted
>
__device__
void
scaled_fp8_conversion_vec
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
__device__
void
scaled_fp8_conversion_vec
(
FP8_TYPE
*
__restrict__
out
,
scalar_t
const
*
__restrict__
input
,
float
const
scale
,
int64_t
const
num_elems
,
...
...
@@ -160,7 +175,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
}
template
<
typename
scalar_t
>
__global__
void
scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
__global__
void
scaled_fp8_quant_kernel
(
FP8_TYPE
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
float
*
__restrict__
scale
,
int64_t
num_elems
)
{
...
...
@@ -175,7 +190,7 @@ __global__ void scaled_fp8_quant_kernel(c10::Float8_e4m3fn* __restrict__ out,
template
<
typename
scalar_t
>
__global__
void
dynamic_per_token_scaled_fp8_quant_kernel
(
c10
::
Float8_e4m3fn
*
__restrict__
out
,
float
*
__restrict__
scale
,
FP8_TYPE
*
__restrict__
out
,
float
*
__restrict__
scale
,
scalar_t
const
*
__restrict__
input
,
float
const
*
__restrict__
scale_ub
,
const
int
hidden_size
)
{
float
const
min_scaling_factor
=
1.0
f
/
(
FP8_E4M3_MAX
*
512.
f
);
...
...
@@ -184,7 +199,7 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
int
const
token_idx
=
blockIdx
.
x
;
scalar_t
const
*
__restrict__
token_input
=
&
input
[
token_idx
*
hidden_size
];
c10
::
Float8_e4m3fn
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
FP8_TYPE
*
__restrict__
token_output
=
&
out
[
token_idx
*
hidden_size
];
// For vectorization, token_input and token_output pointers need to be
// aligned at 8-byte and 4-byte addresses respectively.
...
...
@@ -241,7 +256,7 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
...
...
@@ -261,7 +276,7 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
vllm
::
segmented_max_reduction
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
scale
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
num_elems
);
vllm
::
scaled_fp8_quant_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
FP8_TYPE
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale
.
data_ptr
<
float
>
(),
num_elems
);
});
}
...
...
@@ -284,7 +299,7 @@ void dynamic_per_token_scaled_fp8_quant(
input
.
scalar_type
(),
"dynamic_per_token_scaled_fp8_quant_kernel"
,
[
&
]
{
vllm
::
dynamic_per_token_scaled_fp8_quant_kernel
<
scalar_t
>
<<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
c10
::
Float8_e4m3fn
>
(),
scales
.
data_ptr
<
float
>
(),
out
.
data_ptr
<
FP8_TYPE
>
(),
scales
.
data_ptr
<
float
>
(),
input
.
data_ptr
<
scalar_t
>
(),
scale_ub
.
has_value
()
?
scale_ub
->
data_ptr
<
float
>
()
:
nullptr
,
hidden_size
);
...
...
tests/kernels/quant_utils.py
View file @
e837b624
...
...
@@ -2,6 +2,13 @@ from typing import Optional, Tuple, Union
import
torch
from
vllm.utils
import
is_hip
# Using the default value (240.0) from pytorch will cause accuracy
# issue on dynamic quantization models. Here use 224.0 for rocm.
ROCM_FP8_MAX
=
224.0
FP8_DTYPE
=
torch
.
float8_e4m3fnuz
if
is_hip
()
else
torch
.
float8_e4m3fn
def
as_float32_tensor
(
x
:
Union
[
float
,
torch
.
tensor
])
->
torch
.
tensor
:
return
torch
.
as_tensor
(
x
,
dtype
=
torch
.
float32
,
device
=
'cuda'
)
...
...
@@ -11,13 +18,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
scale_ub
:
Optional
[
torch
.
tensor
]
=
None
)
\
->
Tuple
[
torch
.
tensor
,
torch
.
tensor
]:
assert
quant_dtype
in
[
torch
.
int8
,
torch
.
float8_e4m3fn
]
assert
quant_dtype
in
[
torch
.
int8
,
FP8_DTYPE
]
if
scale_ub
is
not
None
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
assert
quant_dtype
==
FP8_DTYPE
qtype_traits
=
torch
.
iinfo
(
quant_dtype
)
if
quant_dtype
==
torch
.
int8
\
else
torch
.
finfo
(
quant_dtype
)
qtype_max
=
as_float32_tensor
(
qtype_traits
.
max
)
qtype_traits_max
=
ROCM_FP8_MAX
if
is_hip
()
else
qtype_traits
.
max
qtype_traits_min
=
-
ROCM_FP8_MAX
if
is_hip
()
else
qtype_traits
.
min
qtype_max
=
as_float32_tensor
(
qtype_traits_max
)
s_1
=
as_float32_tensor
(
1.0
)
s_512
=
as_float32_tensor
(
512.0
)
...
...
@@ -37,15 +46,15 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
iscales
=
as_float32_tensor
(
s_1
/
scales
)
torch_out
=
as_float32_tensor
(
x
)
*
iscales
torch_out
=
torch_out
.
round
()
torch_out
=
torch_out
.
clamp
(
qtype_traits
.
min
,
qtype_traits
.
max
).
to
(
quant_dtype
)
torch_out
=
torch_out
.
clamp
(
qtype_traits
_
min
,
qtype_traits
_
max
).
to
(
quant_dtype
)
else
:
assert
quant_dtype
==
torch
.
float8_e4m3fn
assert
quant_dtype
==
FP8_DTYPE
min_scaling_factor
=
s_1
/
(
qtype_max
*
s_512
)
scales
=
scales
.
clamp
(
min
=
min_scaling_factor
)
torch_out
=
as_float32_tensor
(
x
)
/
scales
torch_out
=
torch_out
.
clamp
(
qtype_traits
.
min
,
qtype_traits
.
max
).
to
(
quant_dtype
)
torch_out
=
torch_out
.
clamp
(
qtype_traits
_
min
,
qtype_traits
_
max
).
to
(
quant_dtype
)
return
torch_out
,
scales
...
...
@@ -56,8 +65,10 @@ def ref_dynamic_per_token_quant(x: torch.tensor,
def
ref_dynamic_per_tensor_fp8_quant
(
x
:
torch
.
tensor
)
\
->
Tuple
[
torch
.
tensor
,
torch
.
tensor
]:
fp8_traits
=
torch
.
finfo
(
torch
.
float8_e4m3fn
)
fp8_max
=
as_float32_tensor
(
fp8_traits
.
max
)
fp8_traits
=
torch
.
finfo
(
FP8_DTYPE
)
fp8_traits_max
=
ROCM_FP8_MAX
if
is_hip
()
else
fp8_traits
.
max
fp8_traits_min
=
-
ROCM_FP8_MAX
if
is_hip
()
else
fp8_traits
.
min
fp8_max
=
as_float32_tensor
(
fp8_traits_max
)
one
=
as_float32_tensor
(
1.0
)
# For fp8, in order to match the cuda kernel output, we have to do exactly
...
...
@@ -68,5 +79,5 @@ def ref_dynamic_per_tensor_fp8_quant(x: torch.tensor) \
ref_scale
=
x_max
/
fp8_max
ref_iscale
=
one
/
ref_scale
ref_out
=
(
as_float32_tensor
(
x
)
*
ref_iscale
).
clamp
(
fp8_traits
.
min
,
fp8_traits
.
max
).
to
(
dtype
=
torch
.
float8_e4m3fn
)
fp8_traits
_
min
,
fp8_traits
_
max
).
to
(
FP8_DTYPE
)
return
ref_out
,
ref_scale
.
view
((
1
,
))
tests/kernels/test_fp8_quant.py
View file @
e837b624
...
...
@@ -2,7 +2,8 @@ import pytest
import
torch
import
vllm._custom_ops
as
ops
from
tests.kernels.quant_utils
import
(
ref_dynamic_per_tensor_fp8_quant
,
from
tests.kernels.quant_utils
import
(
FP8_DTYPE
,
ref_dynamic_per_tensor_fp8_quant
,
ref_dynamic_per_token_quant
)
DTYPES
=
[
torch
.
half
,
torch
.
bfloat16
,
torch
.
float
]
...
...
@@ -31,8 +32,7 @@ def test_dynamic_per_token_fp8_quant(num_tokens: int, hidden_size: int,
scale_ub
=
torch
.
mean
(
x
).
to
(
dtype
=
torch
.
float32
,
device
=
'cuda'
)
\
if
scale_ub
else
None
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
torch
.
float8_e4m3fn
,
scale_ub
)
ref_out
,
ref_scales
=
ref_dynamic_per_token_quant
(
x
,
FP8_DTYPE
,
scale_ub
)
ops_out
,
ops_scales
=
ops
.
scaled_fp8_quant
(
x
,
scale_ub
=
scale_ub
,
use_per_token_if_dynamic
=
True
)
...
...
vllm/_custom_ops.py
View file @
e837b624
...
...
@@ -369,9 +369,12 @@ def scaled_fp8_quant(
# This code assumes batch_dim and num_tokens are flattened
assert
(
input
.
ndim
==
2
)
shape
:
Union
[
Tuple
[
int
,
int
],
torch
.
Size
]
=
input
.
shape
# For rocm, the output fp8 dtype is torch.float_e3m3fnuz
out_dtype
:
torch
.
dtype
=
torch
.
float8_e4m3fnuz
if
vllm
.
utils
.
is_hip
()
\
else
torch
.
float8_e4m3fn
if
num_token_padding
:
shape
=
(
max
(
num_token_padding
,
input
.
shape
[
0
]),
shape
[
1
])
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
torch
.
float8_e4m3fn
)
output
=
torch
.
empty
(
shape
,
device
=
input
.
device
,
dtype
=
out_dtype
)
if
scale
is
None
:
if
use_per_token_if_dynamic
:
...
...
vllm/config.py
View file @
e837b624
...
...
@@ -240,7 +240,7 @@ class ModelConfig:
def
_verify_quantization
(
self
)
->
None
:
supported_quantization
=
[
*
QUANTIZATION_METHODS
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
]
rocm_supported_quantization
=
[
"gptq"
,
"squeezellm"
,
"fp8"
]
optimized_quantization_methods
=
[
"fp8"
,
"marlin"
,
"gptq_marlin_24"
,
"gptq_marlin"
,
"awq_marlin"
,
"fbgemm_fp8"
,
"compressed_tensors"
,
"compressed-tensors"
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
e837b624
...
...
@@ -20,10 +20,11 @@ from vllm.model_executor.layers.quantization.utils.quant_utils import (
from
vllm.model_executor.layers.quantization.utils.w8a8_utils
import
(
all_close_1d
,
apply_fp8_linear
,
convert_to_channelwise
,
create_per_tensor_scale_param
,
cutlass_fp8_supported
,
per_tensor_dequantize
,
requantize_with_max_scale
)
normalize_e4m3fn_to_e4m3fnuz
,
per_tensor_dequantize
,
requantize_with_max_scale
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
print_warning_once
from
vllm.utils
import
is_hip
,
print_warning_once
ACTIVATION_SCHEMES
=
[
"static"
,
"dynamic"
]
...
...
@@ -120,6 +121,9 @@ class Fp8LinearMethod(LinearMethodBase):
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
self
.
use_marlin
=
capability
<
89
or
envs
.
VLLM_TEST_FORCE_FP8_MARLIN
# Disable marlin for rocm
if
is_hip
():
self
.
use_marlin
=
False
def
create_weights
(
self
,
...
...
@@ -168,6 +172,8 @@ class Fp8LinearMethod(LinearMethodBase):
scale
=
create_per_tensor_scale_param
(
output_partition_sizes
,
**
extra_weight_attrs
)
layer
.
register_parameter
(
"input_scale"
,
scale
)
else
:
layer
.
register_parameter
(
"input_scale"
,
None
)
def
process_weights_after_loading
(
self
,
layer
:
Module
)
->
None
:
# If checkpoint not serialized fp8, quantize the weights.
...
...
@@ -202,9 +208,23 @@ class Fp8LinearMethod(LinearMethodBase):
# requantize the logical shards as a single weight.
else
:
# Dequant -> Quant with max scale so we can run per tensor.
weight
=
layer
.
weight
weight_scale
=
layer
.
weight_scale
# If rocm, use float8_e4m3fnuz.
if
is_hip
():
weight
,
weight_scale
,
input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
weight
=
weight
,
weight_scale
=
weight_scale
,
input_scale
=
layer
.
input_scale
)
if
input_scale
is
not
None
:
layer
.
input_scale
=
Parameter
(
input_scale
,
requires_grad
=
False
)
weight_scale
,
weight
=
requantize_with_max_scale
(
weight
=
layer
.
weight
,
weight_scale
=
layer
.
weight_scale
,
weight
=
weight
,
weight_scale
=
weight_scale
,
logical_widths
=
layer
.
logical_widths
,
)
...
...
@@ -214,8 +234,6 @@ class Fp8LinearMethod(LinearMethodBase):
if
self
.
quant_config
.
activation_scheme
==
"static"
:
layer
.
input_scale
=
Parameter
(
layer
.
input_scale
.
max
(),
requires_grad
=
False
)
else
:
layer
.
input_scale
=
None
if
self
.
use_marlin
:
prepare_fp8_layer_for_marlin
(
layer
)
...
...
@@ -346,10 +364,12 @@ class Fp8MoEMethod(FusedMoEMethodBase):
# If checkpoint is fp16, quantize in place.
if
not
self
.
quant_config
.
is_checkpoint_fp8_serialized
:
# If rocm, use float8_e4m3fnuz as dtype
fp8_dtype
=
torch
.
float8_e4m3fnuz
\
if
is_hip
()
else
torch
.
float8_e4m3fn
w13_weight
=
torch
.
empty_like
(
layer
.
w13_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
torch
.
float8_e4m3fn
)
dtype
=
fp8_dtype
)
w2_weight
=
torch
.
empty_like
(
layer
.
w2_weight
.
data
,
dtype
=
fp8_dtype
)
# Re-initialize w13_scale because we directly quantize
# merged w13 weights and generate a single scaling factor.
...
...
@@ -393,6 +413,32 @@ class Fp8MoEMethod(FusedMoEMethodBase):
layer
.
w13_input_scale
.
max
(),
requires_grad
=
False
)
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
layer
.
w2_input_scale
.
max
(),
requires_grad
=
False
)
# If rocm, normalize the weights and scales to e4m3fnuz
if
is_hip
():
# Normalize the weights and scales
w13_weight
,
w13_weight_scale
,
w13_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w13_weight
,
layer
.
w13_weight_scale
,
layer
.
w13_input_scale
)
w2_weight
,
w2_weight_scale
,
w2_input_scale
=
\
normalize_e4m3fn_to_e4m3fnuz
(
layer
.
w2_weight
,
layer
.
w2_weight_scale
,
layer
.
w2_input_scale
)
# Reset the parameter
layer
.
w13_weight
=
torch
.
nn
.
Parameter
(
w13_weight
,
requires_grad
=
False
)
layer
.
w13_weight_scale
=
torch
.
nn
.
Parameter
(
w13_weight_scale
,
requires_grad
=
False
)
if
w13_input_scale
is
not
None
:
layer
.
w13_input_scale
=
torch
.
nn
.
Parameter
(
w13_input_scale
,
requires_grad
=
False
)
layer
.
w2_weight
=
torch
.
nn
.
Parameter
(
w2_weight
,
requires_grad
=
False
)
layer
.
w2_weight_scale
=
torch
.
nn
.
Parameter
(
w2_weight_scale
,
requires_grad
=
False
)
if
w2_input_scale
is
not
None
:
layer
.
w2_input_scale
=
torch
.
nn
.
Parameter
(
w2_input_scale
,
requires_grad
=
False
)
# Fp8 moe kernel needs single weight scale for w13 per expert.
# We take the max then dequant and requant each expert.
...
...
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
e837b624
...
...
@@ -6,9 +6,19 @@ from torch.nn import Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.utils
import
is_hip
# scaled_mm in pytorch on rocm has a bug that requires always
# providing scaling factor for result. This value is created
# as global value to avoid multiple tensor allocations, and
# can be removed once pytorch fixes the bug.
TORCH_SCALED_MM_SCALE_RESULT
=
torch
.
ones
(
1
).
cuda
()
if
is_hip
()
else
None
def
cutlass_fp8_supported
()
->
bool
:
# cutlass is not supported on Rocm
if
is_hip
():
return
False
capability
=
current_platform
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
...
...
@@ -147,13 +157,19 @@ def apply_fp8_linear(
if
per_tensor_weights
and
per_tensor_activations
:
# Fused GEMM_DQ
output
,
_
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
bias
=
bias
)
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
output
=
torch
.
_scaled_mm
(
qinput
,
weight
,
out_dtype
=
input
.
dtype
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_result
=
TORCH_SCALED_MM_SCALE_RESULT
,
bias
=
bias
)
# Since in torch 2.5, scaled_mm only returns single value
# This should be removed when vllm-nvidia also moves to 2.5
if
is_hip
():
return
torch
.
narrow
(
output
,
0
,
0
,
input
.
shape
[
0
])
return
torch
.
narrow
(
output
[
0
],
0
,
0
,
input
.
shape
[
0
])
else
:
# Fallback for channelwise case, where we use unfused DQ
...
...
@@ -207,3 +223,27 @@ def apply_int8_linear(
scale_b
=
weight_scale
,
out_dtype
=
input
.
dtype
,
bias
=
bias
)
def
normalize_e4m3fn_to_e4m3fnuz
(
weight
:
torch
.
Tensor
,
weight_scale
:
torch
.
Tensor
,
input_scale
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
Optional
[
torch
.
Tensor
]]:
assert
weight
.
dtype
==
torch
.
float8_e4m3fn
# The bits pattern 10000000(-128) represents zero in e4m3fn
# but NaN in e4m3fnuz. So here we set it to 0.
# https://onnx.ai/onnx/technical/float8.html
weight_as_int8
=
weight
.
view
(
torch
.
int8
)
ROCM_FP8_NAN_AS_INT
=
-
128
weight_as_int8
[
weight_as_int8
==
ROCM_FP8_NAN_AS_INT
]
=
0
weight
=
weight_as_int8
.
view
(
torch
.
float8_e4m3fnuz
)
# For the same bits representation, e4m3fnuz value is half of
# the e4m3fn value, so we should double the scaling factor to
# get the same dequantized value.
# https://onnx.ai/onnx/technical/float8.html
weight_scale
=
weight_scale
*
2.0
if
input_scale
is
not
None
:
input_scale
=
input_scale
*
2.0
return
weight
,
weight_scale
,
input_scale
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment