Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f1b68618
Unverified
Commit
f1b68618
authored
Jan 23, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Jan 23, 2025
Browse files
use flashinfer vec_dtypes in sgl_kernel (#3083)
parent
0da0989a
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
80 deletions
+51
-80
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
+25
-22
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
+0
-29
sgl-kernel/tests/test_sampling_scaling_penalties.py
sgl-kernel/tests/test_sampling_scaling_penalties.py
+26
-29
No files found.
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
View file @
f1b68618
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <pytorch_extension_utils.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
#include "utils.h"
#include "vectorization.cuh"
// Applies sampling scaling penalties element-wise:
//   output[i] = logits[i] > 0 ? logits[i] / penalty[i] : logits[i] * penalty[i]
// Vectorized with flashinfer::vec_t using 128-bit (16-byte) loads/stores; a
// scalar tail loop handles the elements left over when numel is not a
// multiple of vec_size.
//
// NOTE(review): the diff view elides part of the signature; the trailing
// `scalar_t* output, const int32_t numel` parameters are reconstructed from
// the body's uses of `output` and `numel` — confirm against the repository.
template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int32_t numel) {
  // Grid-wide linear thread id and total thread count, used to stride over
  // the flattened input.
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;

  // Elements per vector chosen so each vector is 16 bytes wide
  // (e.g. 8 for half/bfloat16, 4 for float).
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;

  const int32_t num_vec_elems = numel / vec_size;

#pragma unroll 1
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec_t logits_vec, penalties_vec, out_vec;
    // cast_load reads vec_size contiguous elements starting at the pointer.
    logits_vec.cast_load(logits + i * vec_size);
    penalties_vec.cast_load(scaling_penalties + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      // Positive logits are shrunk by division, non-positive ones by
      // multiplication (penalties are expected > 0; see the test which
      // samples them from [0.5, 1.5)).
      out_vec[j] =
          logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j] : logits_vec[j] * penalties_vec[j];
    }
    out_vec.cast_store(output + i * vec_size);
  }

  // process the remaining elements
  const int32_t start_idx = num_vec_elems * vec_size;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
  }
}
// --- Tail of torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits,
//                                                      const torch::Tensor& scaling_penalties) ---
// NOTE(review): the hunk (@@ -48,12 +49,14 @@) starts mid-function; `output`,
// `numel`, and `threads` are defined in the elided lines above — confirm
// against the repository before applying.

  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  // Dispatch over fp16/bf16 (and fp32, per the macro) instead of the previous
  // AT_DISPATCH_FLOATING_TYPES_AND2; the lambda returns true on success.
  DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16(logits.scalar_type(), scalar_t, [&] {
    // Keep the launch geometry in sync with the kernel: each thread consumes
    // vec_size = 16 / sizeof(scalar_t) elements per vectorized iteration.
    uint32_t vec_size = 16 / sizeof(scalar_t);
    const int blocks = (numel + threads * vec_size - 1) / (threads * vec_size);
    sampling_scaling_penalties_kernel<scalar_t><<<blocks, threads, 0, stream>>>(
        static_cast<scalar_t*>(logits.data_ptr()), static_cast<scalar_t*>(scaling_penalties.data_ptr()),
        static_cast<scalar_t*>(output.data_ptr()), numel);
    return true;
  });

  return output;
}
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
deleted
100644 → 0
View file @
0da0989a
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
// NOTE(review): this header is DELETED by commit f1b68618 (replaced by
// flashinfer::vec_t); reconstructed here only as the clean pre-commit content.
#pragma once
/**
 * __device__ datatypes vectorized by 4
 */

// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>

// Vectorization containers

// Four scalars packed together; __align__(8) lets the compiler emit a single
// wide load/store for the group.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;
  scalar_t y;
  scalar_t z;
  scalar_t w;
};

// Four quantized (8-bit) values packed into 4 bytes; restricted at compile
// time to the supported int8/fp8 element types.
template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> || std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
  quant_type_t x;
  quant_type_t y;
  quant_type_t z;
  quant_type_t w;
};
sgl-kernel/tests/test_sampling_scaling_penalties.py
View file @
f1b68618
import pytest
import torch
from sgl_kernel import sampling_scaling_penalties


# Post-commit version: the old hand-rolled nested for-loops over module-level
# batch_sizes/vocab_sizes/dtypes lists are replaced by pytest parametrization,
# so each (batch_size, vocab_size, dtype) combination is its own test case.
@pytest.mark.parametrize("batch_size", [1, 2, 4, 8, 16, 32, 64, 65])
@pytest.mark.parametrize("vocab_size", [2048, 4096, 8192, 16384, 32768, 32767])
@pytest.mark.parametrize("dtype", [torch.half, torch.bfloat16])
def test_sampling_scaling_penalties(batch_size, vocab_size, dtype):
    """Check the CUDA kernel against a pure-torch reference implementation.

    Reference: out = logits / penalty where logits > 0, else logits * penalty.
    Odd sizes (65, 32767) exercise the kernel's scalar tail path.
    """
    device = torch.device("cuda")
    rtol = 1e-3
    atol = 1e-3

    logits = torch.randn(batch_size, vocab_size, device=device, dtype=dtype)
    # Penalties sampled from [0.5, 1.5) so they are strictly positive.
    scaling_penalties = (
        torch.rand(batch_size, vocab_size, device=device, dtype=dtype) + 0.5
    )

    ref_output = torch.where(
        logits > 0,
        logits / scaling_penalties,
        logits * scaling_penalties,
    )

    kernel_output = sampling_scaling_penalties(logits, scaling_penalties)

    torch.testing.assert_close(
        kernel_output,
        ref_output,
        rtol=rtol,
        atol=atol,
        msg=f"Failed for batch_size={batch_size}, vocab_size={vocab_size}, dtype={dtype}",
    )


if __name__ == "__main__":
    # Body elided in the diff view ("..."); presumably runs pytest on this
    # file — confirm against the repository.
    ...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment