Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fea59c77
Unverified
Commit
fea59c77
authored
Jul 22, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jul 22, 2024
Browse files
[Bugfix][Kernel] Use int64_t for indices in fp8 quant kernels (#6649)
parent
739b61a3
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
31 additions
and
6 deletions
+31
-6
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+6
-6
tests/kernels/test_fp8_quant.py
tests/kernels/test_fp8_quant.py
+25
-0
No files found.
csrc/quantization/fp8/common.cu
View file @
fea59c77
...
@@ -103,11 +103,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
...
@@ -103,11 +103,11 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
int
const
num_vec_elems
=
num_elems
>>
2
;
int
64_t
const
num_vec_elems
=
num_elems
>>
2
;
float
absmax_val
=
0.0
f
;
float
absmax_val
=
0.0
f
;
#pragma unroll 4
#pragma unroll 4
for
(
int
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
for
(
int
64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
x
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
x
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
y
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
y
));
...
@@ -116,7 +116,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
...
@@ -116,7 +116,7 @@ __device__ float thread_max_vec(scalar_t const* __restrict__ input,
}
}
// Handle the remaining elements if num_elems is not divisible by 4
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
for
(
int
64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
absmax_val
=
max
(
absmax_val
,
fabs
(
input
[
i
]));
absmax_val
=
max
(
absmax_val
,
fabs
(
input
[
i
]));
}
}
...
@@ -134,10 +134,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
...
@@ -134,10 +134,10 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
float8x4_t
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
float8x4_t
*
vectorized_out
=
reinterpret_cast
<
float8x4_t
*>
(
out
);
int
const
num_vec_elems
=
num_elems
>>
2
;
int
64_t
const
num_vec_elems
=
num_elems
>>
2
;
#pragma unroll 4
#pragma unroll 4
for
(
int
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
for
(
int
64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
float8x4_t
out_vec
;
...
@@ -153,7 +153,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
...
@@ -153,7 +153,7 @@ __device__ void scaled_fp8_conversion_vec(c10::Float8_e4m3fn* __restrict__ out,
}
}
// Handle the remaining elements if num_elems is not divisible by 4
// Handle the remaining elements if num_elems is not divisible by 4
for
(
int
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
for
(
int
64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
...
...
tests/kernels/test_fp8_quant.py
View file @
fea59c77
...
@@ -60,3 +60,28 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
...
@@ -60,3 +60,28 @@ def test_dynamic_per_tensor_fp8_quant(num_tokens: int, hidden_size: int,
assert
torch
.
allclose
(
ref_scale
,
ops_scale
)
assert
torch
.
allclose
(
ref_scale
,
ops_scale
)
assert
torch
.
allclose
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
assert
torch
.
allclose
(
ref_out
.
to
(
dtype
=
torch
.
float32
),
ops_out
.
to
(
dtype
=
torch
.
float32
))
ops_out
.
to
(
dtype
=
torch
.
float32
))
# Regression test for a case with large activations where an int32 index cannot
# represent the number of elements.
@
torch
.
inference_mode
()
@
pytest
.
mark
.
parametrize
(
"seed"
,
SEEDS
)
def
test_fp8_quant_large
(
seed
:
int
)
->
None
:
torch
.
random
.
manual_seed
(
seed
)
torch
.
cuda
.
manual_seed
(
seed
)
num_tokens
=
1024000
# Mistral-Nemo's max_position_embeddings
hidden_size
=
1152
# Smallest hidden_size to reproduce the error
dtype
=
torch
.
bfloat16
x
=
torch
.
rand
(
num_tokens
,
hidden_size
,
dtype
=
dtype
,
device
=
"cuda"
)
ref_out
,
scale
=
ref_dynamic_per_tensor_fp8_quant
(
x
)
ops_out
,
_
=
ops
.
scaled_fp8_quant
(
x
,
scale
)
# Minimize memory footprint in this test by freeing x and upconverting
# the outputs in place. (torch.allclose does not support fp8)
del
x
ref_out
=
ref_out
.
to
(
dtype
=
dtype
)
ops_out
=
ops_out
.
to
(
dtype
=
dtype
)
assert
torch
.
allclose
(
ref_out
,
ops_out
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment