Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
5d9d15e7
Unverified
Commit
5d9d15e7
authored
Jan 25, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Jan 25, 2025
Browse files
support fp32 in sampling_scaling_penalties kernel (#3121)
parent
665e5e85
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
5 deletions
+26
-5
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
+1
-2
sgl-kernel/src/sgl-kernel/csrc/utils.h
sgl-kernel/src/sgl-kernel/csrc/utils.h
+18
-0
sgl-kernel/tests/test_sampling_scaling_penalties.py
sgl-kernel/tests/test_sampling_scaling_penalties.py
+7
-3
No files found.
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
View file @
5d9d15e7
#include <ATen/ATen.h>
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <pytorch_extension_utils.h>
#include <THC/THCAtomics.cuh>
#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include <flashinfer/vec_dtypes.cuh>
...
@@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
...
@@ -49,7 +48,7 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FP16
(
logits
.
scalar_type
(),
scalar_t
,
[
&
]
{
DISPATCH_PYTORCH_DTYPE_TO_CTYPE_
FLOAT_
FP16
(
logits
.
scalar_type
(),
scalar_t
,
[
&
]
{
uint32_t
vec_size
=
16
/
sizeof
(
scalar_t
);
uint32_t
vec_size
=
16
/
sizeof
(
scalar_t
);
const
int
blocks
=
(
numel
+
threads
*
vec_size
-
1
)
/
(
threads
*
vec_size
);
const
int
blocks
=
(
numel
+
threads
*
vec_size
-
1
)
/
(
threads
*
vec_size
);
sampling_scaling_penalties_kernel
<
scalar_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
sampling_scaling_penalties_kernel
<
scalar_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
...
...
sgl-kernel/src/sgl-kernel/csrc/utils.h
View file @
5d9d15e7
#pragma once
#pragma once
#include <pytorch_extension_utils.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <sstream>
#include <sstream>
...
@@ -44,3 +45,20 @@ inline int getSMVersion() {
...
@@ -44,3 +45,20 @@ inline int getSMVersion() {
CHECK_CUDA_SUCCESS
(
cudaDeviceGetAttribute
(
&
sm_minor
,
cudaDevAttrComputeCapabilityMinor
,
device
));
CHECK_CUDA_SUCCESS
(
cudaDeviceGetAttribute
(
&
sm_minor
,
cudaDevAttrComputeCapabilityMinor
,
device
));
return
sm_major
*
10
+
sm_minor
;
return
sm_major
*
10
+
sm_minor
;
}
}
#define DISPATCH_PYTORCH_DTYPE_TO_CTYPE_FLOAT_FP16(pytorch_dtype, c_type, ...) \
[&]() -> bool { \
switch (pytorch_dtype) { \
case at::ScalarType::Float: { \
using c_type = float; \
return __VA_ARGS__(); \
} \
_DISPATCH_CASE_F16(c_type, __VA_ARGS__) \
_DISPATCH_CASE_BF16(c_type, __VA_ARGS__) \
default: \
std::ostringstream oss; \
oss << __PRETTY_FUNCTION__ << " failed to dispatch data type " << pytorch_dtype; \
TORCH_CHECK(false, oss.str()); \
return false; \
} \
}()
sgl-kernel/tests/test_sampling_scaling_penalties.py
View file @
5d9d15e7
...
@@ -2,10 +2,14 @@ import pytest
...
@@ -2,10 +2,14 @@ import pytest
import
torch
import
torch
from
sgl_kernel
import
sampling_scaling_penalties
from
sgl_kernel
import
sampling_scaling_penalties
batch_sizes
=
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
65
]
vocab_sizes
=
[
2048
,
4096
,
8192
,
16384
,
32768
,
32767
]
dtypes
=
[
torch
.
float32
,
torch
.
half
,
torch
.
bfloat16
]
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
65
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
2048
,
4096
,
8192
,
16384
,
32768
,
32767
])
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
batch_sizes
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
,
torch
.
bfloat16
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
vocab_sizes
)
@
pytest
.
mark
.
parametrize
(
"dtype"
,
dtypes
)
def
test_sampling_scaling_penalties
(
batch_size
,
vocab_size
,
dtype
):
def
test_sampling_scaling_penalties
(
batch_size
,
vocab_size
,
dtype
):
device
=
torch
.
device
(
"cuda"
)
device
=
torch
.
device
(
"cuda"
)
rtol
=
1e-3
rtol
=
1e-3
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment