Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e31446b6
Unverified
Commit
e31446b6
authored
Jun 03, 2025
by
Michael Goin
Committed by
GitHub
Jun 03, 2025
Browse files
[Perf] Tune `scaled_fp8_quant` by increasing vectorization (#18844)
Signed-off-by:
mgoin
<
mgoin64@gmail.com
>
parent
bdf13965
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
115 additions
and
110 deletions
+115
-110
csrc/quantization/fp8/common.cu
csrc/quantization/fp8/common.cu
+19
-16
csrc/quantization/fp8/common.cuh
csrc/quantization/fp8/common.cuh
+35
-33
csrc/quantization/fused_kernels/layernorm_utils.cuh
csrc/quantization/fused_kernels/layernorm_utils.cuh
+50
-49
csrc/quantization/vectorization.cuh
csrc/quantization/vectorization.cuh
+11
-12
No files found.
csrc/quantization/fp8/common.cu
View file @
e31446b6
...
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -39,8 +39,8 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
fp8_type
*
__restrict__
token_output
=
&
out
[
offset
];
// For vectorization, token_input and token_output pointers need to be
// For vectorization, token_input and token_output pointers need to be
// aligned at
8
-byte and
4
-byte addresses respectively.
// aligned at
32
-byte and
16
-byte addresses respectively.
bool
const
can_vectorize
=
hidden_size
%
4
==
0
;
bool
const
can_vectorize
=
hidden_size
%
16
==
0
;
float
absmax_val
=
0.0
f
;
float
absmax_val
=
0.0
f
;
if
(
can_vectorize
)
{
if
(
can_vectorize
)
{
...
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
...
@@ -48,24 +48,24 @@ __global__ void dynamic_per_token_scaled_fp8_quant_kernel(
}
else
{
}
else
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
float
const
x
=
static_cast
<
float
>
(
token_input
[
i
]);
absmax_val
=
max
(
absmax_val
,
fabs
(
x
));
absmax_val
=
f
max
f
(
absmax_val
,
fabs
f
(
x
));
}
}
}
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
256
>
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
__shared__
typename
BlockReduce
::
TempStorage
reduceStorage
;
float
const
block_absmax_val_maybe
=
float
const
block_absmax_val_maybe
=
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
BlockReduce
(
reduceStorage
).
Reduce
(
absmax_val
,
cub
::
Max
{},
blockDim
.
x
);
__shared__
float
token_scale
;
__shared__
float
token_scale
;
if
(
tid
==
0
)
{
if
(
tid
==
0
)
{
if
(
scale_ub
)
{
if
(
scale_ub
)
{
token_scale
=
min
(
block_absmax_val_maybe
,
*
scale_ub
);
token_scale
=
f
min
f
(
block_absmax_val_maybe
,
*
scale_ub
);
}
else
{
}
else
{
token_scale
=
block_absmax_val_maybe
;
token_scale
=
block_absmax_val_maybe
;
}
}
// token scale computation
// token scale computation
token_scale
=
max
(
token_scale
/
quant_type_max_v
<
fp8_type
>
,
token_scale
=
f
max
f
(
token_scale
/
quant_type_max_v
<
fp8_type
>
,
min_scaling_factor
<
fp8_type
>::
val
());
min_scaling_factor
<
fp8_type
>::
val
());
scale
[
token_idx
]
=
token_scale
;
scale
[
token_idx
]
=
token_scale
;
}
}
__syncthreads
();
__syncthreads
();
...
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -88,10 +88,11 @@ void static_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
scale
)
// [1]
torch
::
Tensor
const
&
scale
)
// [1]
{
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int
const
block_size
=
256
;
int64_t
num_elems
=
input
.
numel
();
int
const
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
dim3
grid
(
num_tokens
);
int
const
num_elems
=
input
.
numel
();
dim3
block
(
1024
);
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
block_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
...
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
...
@@ -110,10 +111,11 @@ void dynamic_scaled_fp8_quant(torch::Tensor& out, // [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
const
&
input
,
// [..., d]
torch
::
Tensor
&
scale
)
// [1]
torch
::
Tensor
&
scale
)
// [1]
{
{
int64_t
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
int
const
block_size
=
256
;
int64_t
num_elems
=
input
.
numel
();
int
const
num_tokens
=
input
.
numel
()
/
input
.
size
(
-
1
);
dim3
grid
(
num_tokens
);
int
const
num_elems
=
input
.
numel
();
dim3
block
(
1024
);
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
block_size
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
...
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
...
@@ -141,8 +143,9 @@ void dynamic_per_token_scaled_fp8_quant(
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
hidden_size
=
input
.
size
(
-
1
);
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
const
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
const
block_size
=
256
;
dim3
const
grid
(
num_tokens
);
dim3
const
grid
(
num_tokens
);
dim3
const
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
const
block
(
std
::
min
(
hidden_size
,
block_size
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
...
...
csrc/quantization/fp8/common.cuh
View file @
e31446b6
...
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
...
@@ -46,7 +46,7 @@ __device__ __forceinline__ fp8_type scaled_fp8_conversion(float const val,
}
}
float
r
=
float
r
=
fmax
(
-
quant_type_max_v
<
fp8_type
>
,
fmin
(
x
,
quant_type_max_v
<
fp8_type
>
));
fmax
f
(
-
quant_type_max_v
<
fp8_type
>
,
fmin
f
(
x
,
quant_type_max_v
<
fp8_type
>
));
#ifndef USE_ROCM
#ifndef USE_ROCM
return
static_cast
<
fp8_type
>
(
r
);
return
static_cast
<
fp8_type
>
(
r
);
#else
#else
...
@@ -65,7 +65,7 @@ template <typename scalar_t, typename fp8_type>
...
@@ -65,7 +65,7 @@ template <typename scalar_t, typename fp8_type>
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
__global__
void
segmented_max_reduction
(
float
*
__restrict__
scale
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
int64_t
num_elems
)
{
int64_t
num_elems
)
{
__shared__
float
cache
[
1024
];
__shared__
float
cache
[
256
];
int64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
int64_t
i
=
blockDim
.
x
*
blockIdx
.
x
+
threadIdx
.
x
;
// First store maximum for all values processes by
// First store maximum for all values processes by
...
@@ -73,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
...
@@ -73,7 +73,7 @@ __global__ void segmented_max_reduction(float* __restrict__ scale,
scalar_t
tmp
=
0.0
;
scalar_t
tmp
=
0.0
;
while
(
i
<
num_elems
)
{
while
(
i
<
num_elems
)
{
float
x
=
static_cast
<
float
>
(
input
[
i
]);
float
x
=
static_cast
<
float
>
(
input
[
i
]);
tmp
=
max
(
tmp
,
fabs
(
x
));
tmp
=
f
max
f
(
tmp
,
fabs
f
(
x
));
i
+=
blockDim
.
x
*
gridDim
.
x
;
i
+=
blockDim
.
x
*
gridDim
.
x
;
}
}
cache
[
threadIdx
.
x
]
=
tmp
;
cache
[
threadIdx
.
x
]
=
tmp
;
...
@@ -100,25 +100,27 @@ template <typename scalar_t>
...
@@ -100,25 +100,27 @@ template <typename scalar_t>
__device__
float
thread_max_vec
(
scalar_t
const
*
__restrict__
input
,
__device__
float
thread_max_vec
(
scalar_t
const
*
__restrict__
input
,
int64_t
const
num_elems
,
int
const
tid
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
int
const
step
)
{
constexpr
size_t
VEC_SIZE
=
16
;
using
scalarxN_t
=
vec_n_t
<
scalar_t
,
VEC_SIZE
>
;
// Vectorized input/output to better utilize memory bandwidth.
// Vectorized input/output to better utilize memory bandwidth.
vec4_t
<
scalar_t
>
const
*
vectorized_in
=
auto
const
*
vectorized_in
=
reinterpret_cast
<
scalarxN_t
const
*>
(
input
);
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
// num_elems / VEC_SIZE (which is 16)
int64_t
const
num_vec_elems
=
num_elems
>>
4
;
float
absmax_val
=
0.0
f
;
float
absmax_val
=
0.0
f
;
#pragma unroll
4
#pragma unroll
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
scalar
xN
_t
in_vec
=
vectorized_in
[
i
];
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
x
));
#pragma unroll
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
y
));
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
z
));
absmax_val
=
f
max
f
(
absmax_val
,
fabs
f
(
in_vec
.
val
[
j
]
));
absmax_val
=
max
(
absmax_val
,
fabs
(
in_vec
.
w
));
}
}
}
// Handle the remaining elements if num_elems is not divisible by
4
// Handle the remaining elements if num_elems is not divisible by
VEC_SIZE
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
for
(
int64_t
i
=
num_vec_elems
*
VEC_SIZE
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
absmax_val
=
max
(
absmax_val
,
fabs
(
input
[
i
]));
absmax_val
=
f
max
f
(
absmax_val
,
fabs
f
(
input
[
i
]));
}
}
return
absmax_val
;
return
absmax_val
;
...
@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
...
@@ -130,31 +132,31 @@ __device__ void scaled_fp8_conversion_vec(fp8_type* __restrict__ out,
float
const
scale
,
float
const
scale
,
int64_t
const
num_elems
,
int64_t
const
num_elems
,
int
const
tid
,
int
const
step
)
{
int
const
tid
,
int
const
step
)
{
using
float8x4_t
=
q8x4_t
<
fp8_type
>
;
constexpr
size_t
VEC_SIZE
=
16
;
using
scalarxN_t
=
vec_n_t
<
scalar_t
,
VEC_SIZE
>
;
using
float8xN_t
=
q8_n_t
<
fp8_type
,
VEC_SIZE
>
;
// Vectorized input/output to better utilize memory bandwidth.
// Vectorized input/output to better utilize memory bandwidth.
auto
const
*
vectorized_in
=
reinterpret_cast
<
vec4_t
<
scalar_t
>
const
*>
(
input
);
auto
const
*
vectorized_in
=
reinterpret_cast
<
scalar
xN
_t
const
*>
(
input
);
auto
*
vectorized_out
=
reinterpret_cast
<
float8x
4
_t
*>
(
out
);
auto
*
vectorized_out
=
reinterpret_cast
<
float8x
N
_t
*>
(
out
);
int64_t
const
num_vec_elems
=
num_elems
>>
2
;
// num_elems / VEC_SIZE (which is 16)
int64_t
const
num_vec_elems
=
num_elems
>>
4
;
#pragma unroll
4
#pragma unroll
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
for
(
int64_t
i
=
tid
;
i
<
num_vec_elems
;
i
+=
step
)
{
vec4_t
<
scalar_t
>
in_vec
=
vectorized_in
[
i
];
scalarxN_t
in_vec
=
vectorized_in
[
i
];
float8x4_t
out_vec
;
float8xN_t
out_vec
;
out_vec
.
x
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
#pragma unroll
static_cast
<
float
>
(
in_vec
.
x
),
scale
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
out_vec
.
y
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
out_vec
.
val
[
j
]
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
y
),
scale
);
static_cast
<
float
>
(
in_vec
.
val
[
j
]),
scale
);
out_vec
.
z
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
}
static_cast
<
float
>
(
in_vec
.
z
),
scale
);
out_vec
.
w
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
in_vec
.
w
),
scale
);
vectorized_out
[
i
]
=
out_vec
;
vectorized_out
[
i
]
=
out_vec
;
}
}
// Handle the remaining elements if num_elems is not divisible by
4
// Handle the remaining elements if num_elems is not divisible by
VEC_SIZE
for
(
int64_t
i
=
num_vec_elems
*
4
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
for
(
int64_t
i
=
num_vec_elems
*
VEC_SIZE
+
tid
;
i
<
num_elems
;
i
+=
step
)
{
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
out
[
i
]
=
scaled_fp8_conversion
<
is_scale_inverted
,
fp8_type
>
(
static_cast
<
float
>
(
input
[
i
]),
scale
);
static_cast
<
float
>
(
input
[
i
]),
scale
);
}
}
...
...
csrc/quantization/fused_kernels/layernorm_utils.cuh
View file @
e31446b6
...
@@ -140,6 +140,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
...
@@ -140,6 +140,7 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
// sum of squares
// sum of squares
float
ss
=
0.0
f
;
float
ss
=
0.0
f
;
const
int
VEC_SIZE
=
4
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
#pragma unroll 4
#pragma unroll 4
...
@@ -147,22 +148,23 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
...
@@ -147,22 +148,23 @@ __device__ void compute_rms(float* rms, scalar_t const* __restrict__ input,
vec4_t
<
scalar_t
>
in
=
vec_input
[
i
];
vec4_t
<
scalar_t
>
in
=
vec_input
[
i
];
vec4_t
<
float
>
x
;
vec4_t
<
float
>
x
;
x
.
x
=
static_cast
<
float
>
(
in
.
x
);
#pragma unroll
x
.
y
=
static_cast
<
float
>
(
in
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
x
.
z
=
static_cast
<
float
>
(
in
.
z
);
x
.
val
[
j
]
=
static_cast
<
float
>
(
in
.
val
[
j
]);
x
.
w
=
static_cast
<
float
>
(
in
.
w
);
}
if
constexpr
(
has_residual
)
{
if
constexpr
(
has_residual
)
{
vec4_t
<
scalar_t
>
r
=
vec_residual
[
i
];
vec4_t
<
scalar_t
>
r
=
vec_residual
[
i
];
x
.
x
+=
static_cast
<
float
>
(
r
.
x
);
#pragma unroll
x
.
y
+=
static_cast
<
float
>
(
r
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
x
.
z
+=
static_cast
<
float
>
(
r
.
z
);
x
.
val
[
j
]
+=
static_cast
<
float
>
(
r
.
val
[
j
]
);
x
.
w
+=
static_cast
<
float
>
(
r
.
w
);
}
}
}
ss
+=
x
.
x
*
x
.
x
;
#pragma unroll
ss
+=
x
.
y
*
x
.
y
;
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
ss
+=
x
.
z
*
x
.
z
;
ss
+=
x
.
val
[
j
]
*
x
.
val
[
j
]
;
ss
+=
x
.
w
*
x
.
w
;
}
}
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
...
@@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -203,6 +205,7 @@ __device__ void compute_dynamic_per_token_scales(
constexpr
scalar_out_t
qmax
{
quant_type_max_v
<
scalar_out_t
>
};
constexpr
scalar_out_t
qmax
{
quant_type_max_v
<
scalar_out_t
>
};
const
int
VEC_SIZE
=
4
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
float
block_absmax_val_maybe
=
0.0
f
;
float
block_absmax_val_maybe
=
0.0
f
;
...
@@ -212,26 +215,25 @@ __device__ void compute_dynamic_per_token_scales(
...
@@ -212,26 +215,25 @@ __device__ void compute_dynamic_per_token_scales(
vec4_t
<
scalar_t
>
const
w
=
vec_weight
[
i
];
vec4_t
<
scalar_t
>
const
w
=
vec_weight
[
i
];
vec4_t
<
float
>
x
;
vec4_t
<
float
>
x
;
x
.
x
=
static_cast
<
float
>
(
in
.
x
);
#pragma unroll
x
.
y
=
static_cast
<
float
>
(
in
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
x
.
z
=
static_cast
<
float
>
(
in
.
z
);
x
.
val
[
j
]
=
static_cast
<
float
>
(
in
.
val
[
j
]);
x
.
w
=
static_cast
<
float
>
(
in
.
w
);
}
if
constexpr
(
has_residual
)
{
if
constexpr
(
has_residual
)
{
vec4_t
<
scalar_t
>
r
=
vec_residual
[
i
];
vec4_t
<
scalar_t
>
r
=
vec_residual
[
i
];
x
.
x
+=
static_cast
<
float
>
(
r
.
x
);
#pragma unroll
x
.
y
+=
static_cast
<
float
>
(
r
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
x
.
z
+=
static_cast
<
float
>
(
r
.
z
);
x
.
val
[
j
]
+=
static_cast
<
float
>
(
r
.
val
[
j
]
);
x
.
w
+=
static_cast
<
float
>
(
r
.
w
);
}
}
}
block_absmax_val_maybe
=
fmaxf
(
#pragma unroll
block_absmax_val_maybe
,
fabs
(
static_cast
<
scalar_t
>
(
x
.
x
*
rms
)
*
w
.
x
));
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
block_absmax_val_maybe
=
fmaxf
(
block_absmax_val_maybe
=
block_absmax_val_maybe
,
fabs
(
static_cast
<
scalar_t
>
(
x
.
y
*
rms
)
*
w
.
y
));
fmaxf
(
block_absmax_val_maybe
,
block_absmax_val_maybe
=
fmaxf
(
fabs
(
static_cast
<
scalar_t
>
(
x
.
val
[
j
]
*
rms
)
*
w
.
val
[
j
]));
block_absmax_val_maybe
,
fabs
(
static_cast
<
scalar_t
>
(
x
.
z
*
rms
)
*
w
.
z
));
}
block_absmax_val_maybe
=
fmaxf
(
block_absmax_val_maybe
,
fabs
(
static_cast
<
scalar_t
>
(
x
.
w
*
rms
)
*
w
.
w
));
}
}
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
using
BlockReduce
=
cub
::
BlockReduce
<
float
,
1024
>
;
...
@@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
...
@@ -282,6 +284,7 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
vec_residual
=
reinterpret_cast
<
vec4_t
<
scalar_t
>*>
(
&
residual
[
token_offset
]);
vec_residual
=
reinterpret_cast
<
vec4_t
<
scalar_t
>*>
(
&
residual
[
token_offset
]);
}
}
const
int
VEC_SIZE
=
4
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
int32_t
const
num_vec_elems
=
hidden_size
>>
2
;
// TODO(luka/varun) extract into type-agnostic vectorized quant function to
// TODO(luka/varun) extract into type-agnostic vectorized quant function to
...
@@ -292,33 +295,31 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
...
@@ -292,33 +295,31 @@ __device__ void norm_and_quant(scalar_out_t* __restrict__ output,
vec4_t
<
scalar_t
>
const
w
=
vec_weight
[
i
];
vec4_t
<
scalar_t
>
const
w
=
vec_weight
[
i
];
vec4_t
<
float
>
x
;
vec4_t
<
float
>
x
;
x
.
x
=
static_cast
<
float
>
(
in
.
x
);
#pragma unroll
x
.
y
=
static_cast
<
float
>
(
in
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
x
.
z
=
static_cast
<
float
>
(
in
.
z
);
x
.
val
[
j
]
=
static_cast
<
float
>
(
in
.
val
[
j
]);
x
.
w
=
static_cast
<
float
>
(
in
.
w
);
}
if
constexpr
(
has_residual
)
{
if
constexpr
(
has_residual
)
{
vec4_t
<
scalar_t
>
r
=
vec_residual
[
i
];
vec4_t
<
scalar_t
>
r
=
vec_residual
[
i
];
x
.
x
+=
static_cast
<
float
>
(
r
.
x
);
#pragma unroll
x
.
y
+=
static_cast
<
float
>
(
r
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
x
.
z
+=
static_cast
<
float
>
(
r
.
z
);
x
.
val
[
j
]
+=
static_cast
<
float
>
(
r
.
val
[
j
]
);
x
.
w
+=
static_cast
<
float
>
(
r
.
w
);
}
// Update residual
// Update residual
r
.
x
=
static_cast
<
scalar_t
>
(
x
.
x
);
#pragma unroll
r
.
y
=
static_cast
<
scalar_t
>
(
x
.
y
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
r
.
z
=
static_cast
<
scalar_t
>
(
x
.
z
);
r
.
val
[
j
]
=
static_cast
<
scalar_t
>
(
x
.
val
[
j
]
);
r
.
w
=
static_cast
<
scalar_t
>
(
x
.
w
);
}
vec_residual
[
i
]
=
r
;
vec_residual
[
i
]
=
r
;
}
}
q8x4_t
<
scalar_out_t
>
out
;
q8x4_t
<
scalar_out_t
>
out
;
out
.
x
=
ScaledQuant
<
scalar_out_t
,
is_scale_inverted
>::
quant_fn
(
#pragma unroll
static_cast
<
scalar_t
>
(
x
.
x
*
rms
)
*
w
.
x
,
scale
);
for
(
int
j
=
0
;
j
<
VEC_SIZE
;
++
j
)
{
out
.
y
=
ScaledQuant
<
scalar_out_t
,
is_scale_inverted
>::
quant_fn
(
out
.
val
[
j
]
=
ScaledQuant
<
scalar_out_t
,
is_scale_inverted
>::
quant_fn
(
static_cast
<
scalar_t
>
(
x
.
y
*
rms
)
*
w
.
y
,
scale
);
static_cast
<
scalar_t
>
(
x
.
val
[
j
]
*
rms
)
*
w
.
val
[
j
],
scale
);
out
.
z
=
ScaledQuant
<
scalar_out_t
,
is_scale_inverted
>::
quant_fn
(
}
static_cast
<
scalar_t
>
(
x
.
z
*
rms
)
*
w
.
z
,
scale
);
out
.
w
=
ScaledQuant
<
scalar_out_t
,
is_scale_inverted
>::
quant_fn
(
static_cast
<
scalar_t
>
(
x
.
w
*
rms
)
*
w
.
w
,
scale
);
vec_output
[
i
]
=
out
;
vec_output
[
i
]
=
out
;
}
}
}
}
...
...
csrc/quantization/vectorization.cuh
View file @
e31446b6
...
@@ -10,23 +10,22 @@
...
@@ -10,23 +10,22 @@
namespace
vllm
{
namespace
vllm
{
// Vectorization containers
// Vectorization containers
template
<
typename
scalar_t
>
template
<
typename
scalar_t
,
size_t
vec_size
>
struct
__align__
(
8
)
vec4_t
{
struct
__align__
(
vec_size
*
sizeof
(
scalar_t
))
vec_n_t
{
scalar_t
x
;
scalar_t
val
[
vec_size
];
scalar_t
y
;
scalar_t
z
;
scalar_t
w
;
};
};
template
<
typename
quant_type_t
>
template
<
typename
quant_type_t
,
size_t
vec_size
>
struct
__align__
(
4
)
q8
x4
_t
{
struct
__align__
(
vec_size
*
sizeof
(
quant_type_t
)
)
q8
_n
_t
{
static_assert
(
std
::
is_same_v
<
quant_type_t
,
int8_t
>
||
static_assert
(
std
::
is_same_v
<
quant_type_t
,
int8_t
>
||
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fn
>
||
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fnuz
>
);
std
::
is_same_v
<
quant_type_t
,
c10
::
Float8_e4m3fnuz
>
);
quant_type_t
x
;
quant_type_t
val
[
vec_size
];
quant_type_t
y
;
quant_type_t
z
;
quant_type_t
w
;
};
};
template
<
typename
scalar_t
>
using
vec4_t
=
vec_n_t
<
scalar_t
,
4
>
;
template
<
typename
quant_type_t
>
using
q8x4_t
=
q8_n_t
<
quant_type_t
,
4
>
;
}
// namespace vllm
}
// namespace vllm
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment