Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
change
sglang
Commits
f1b68618
"gpu/gpu_info_nvml.c" did not exist on "b0135f4b9b176eab9155b660d04c9ca2a1ec2341"
Unverified
Commit
f1b68618
authored
Jan 23, 2025
by
Xiaoyu Zhang
Committed by
GitHub
Jan 23, 2025
Browse files
use flashinfer vec_dtypes in sgl_kernel (#3083)
parent
0da0989a
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
51 additions
and
80 deletions
+51
-80
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
+25
-22
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
+0
-29
sgl-kernel/tests/test_sampling_scaling_penalties.py
sgl-kernel/tests/test_sampling_scaling_penalties.py
+26
-29
No files found.
sgl-kernel/src/sgl-kernel/csrc/sampling_scaling_penalties.cu
View file @
f1b68618
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <pytorch_extension_utils.h>
#include <THC/THCAtomics.cuh>
#include <flashinfer/vec_dtypes.cuh>
#include "utils.h"
#include "vectorization.cuh"
// Applies repetition/scaling penalties to logits, elementwise:
//   output[i] = logits[i] > 0 ? logits[i] / penalty[i] : logits[i] * penalty[i]
// (positive logits are divided by the penalty, non-positive ones multiplied,
// which is the standard HF-style repetition-penalty rule).
//
// Vectorized with flashinfer::vec_t: each thread loads 16 bytes
// (vec_size = 16 / sizeof(scalar_t) elements) per iteration using a
// grid-stride loop, then a scalar tail loop handles the remainder when
// numel is not a multiple of vec_size.
//
// NOTE(review): the parameter list after `scaling_penalties` was collapsed in
// the scraped diff; `output` / `numel` are inferred from the body and the
// host-side launch (data_ptr<scalar_t>() x3 + numel) — confirm against the
// repository file.
template <typename scalar_t>
__global__ void sampling_scaling_penalties_kernel(const scalar_t* logits, const scalar_t* scaling_penalties,
                                                  scalar_t* output, const int32_t numel) {
  // Global thread id and grid-stride step.
  const int32_t tid = blockIdx.x * blockDim.x + threadIdx.x;
  const int32_t stride = blockDim.x * gridDim.x;

  // 16-byte vectors: 8 elements for fp16/bf16, 4 for fp32.
  constexpr uint32_t vec_size = 16 / sizeof(scalar_t);
  using vec_t = flashinfer::vec_t<scalar_t, vec_size>;
  const int32_t num_vec_elems = numel / vec_size;

// `unroll 1` deliberately disables unrolling of the outer grid-stride loop;
// the inner per-lane loop below is fully unrolled instead.
#pragma unroll 1
  for (int32_t i = tid; i < num_vec_elems; i += stride) {
    vec_t logits_vec, penalties_vec, out_vec;
    logits_vec.cast_load(logits + i * vec_size);
    penalties_vec.cast_load(scaling_penalties + i * vec_size);

#pragma unroll
    for (uint32_t j = 0; j < vec_size; ++j) {
      out_vec[j] = logits_vec[j] > scalar_t(0.0f) ? logits_vec[j] / penalties_vec[j]
                                                  : logits_vec[j] * penalties_vec[j];
    }
    out_vec.cast_store(output + i * vec_size);
  }

  // Scalar tail: elements past the last full 16-byte vector.
  const int32_t start_idx = num_vec_elems * vec_size;
  for (int32_t i = start_idx + tid; i < numel; i += stride) {
    scalar_t logit = logits[i];
    scalar_t penalty = scaling_penalties[i];
    output[i] = logit > scalar_t(0.0f) ? logit / penalty : logit * penalty;
  }
}
...
...
@@ -48,12 +49,14 @@ torch::Tensor sampling_scaling_penalties(const torch::Tensor& logits, const torc
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
AT_
DISPATCH_
FLOATING_TYPES_AND2
(
at
::
ScalarType
::
Half
,
at
::
ScalarType
::
BFloat16
,
logits
.
scalar_type
(),
"sampling_scaling_penalties_kernel"
,
([
&
]
{
const
int
blocks
=
(
numel
+
threads
*
4
-
1
)
/
(
threads
*
4
);
DISPATCH_
PYTORCH_DTYPE_TO_CTYPE_FP16
(
logits
.
scalar_type
(),
scalar_t
,
[
&
]
{
uint32_t
vec_size
=
16
/
sizeof
(
scalar_t
);
const
int
blocks
=
(
numel
+
threads
*
vec_size
-
1
)
/
(
threads
*
vec_size
);
sampling_scaling_penalties_kernel
<
scalar_t
><<<
blocks
,
threads
,
0
,
stream
>>>
(
logits
.
data_ptr
<
scalar_t
>
(),
scaling_penalties
.
data_ptr
<
scalar_t
>
(),
output
.
data_ptr
<
scalar_t
>
(),
numel
);
}));
static_cast
<
scalar_t
*>
(
logits
.
data_ptr
()),
static_cast
<
scalar_t
*>
(
scaling_penalties
.
data_ptr
()),
static_cast
<
scalar_t
*>
(
output
.
data_ptr
()),
numel
);
return
true
;
});
return
output
;
}
sgl-kernel/src/sgl-kernel/csrc/vectorization.cuh
deleted
100644 → 0
View file @
0da0989a
// Adapted from https://github.com/vllm-project/vllm/blob/main/csrc/quantization/vectorization.cuh
#pragma once
/**
* __device__ datatypes vectorized by 4
*/
// Include both AMD and NVIDIA fp8 types to avoid circular import
// TODO(luka/varun) use FP8_TYPE instead after refactoring
#include <c10/util/Float8_e4m3fn.h>
#include <c10/util/Float8_e4m3fnuz.h>
// Vectorization containers
// Generic 4-wide vector of an arbitrary scalar type, used to coalesce
// global-memory traffic by loading/storing four elements at a time.
// __align__(8) raises the struct's alignment so the compiler can emit
// wider vectorized loads/stores through pointers cast to vec4_t*.
// NOTE(review): member layout (x, y, z, w in order) is part of the ABI
// relied on by reinterpret_cast in the kernels — do not reorder.
template <typename scalar_t>
struct __align__(8) vec4_t {
  scalar_t x;  // element 0
  scalar_t y;  // element 1
  scalar_t z;  // element 2
  scalar_t w;  // element 3
};
// 4-wide vector of 8-bit quantized values (one 32-bit word total), for
// vectorized stores of quantized output. __align__(4) lets four 1-byte
// elements be written as a single aligned 32-bit transaction.
// The static_assert restricts instantiation to the supported 8-bit
// quantization types: int8 and the NVIDIA/AMD fp8-e4m3 variants
// (both included to avoid a circular import — see file header).
template <typename quant_type_t>
struct __align__(4) q8x4_t {
  static_assert(std::is_same_v<quant_type_t, int8_t> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fn> ||
                std::is_same_v<quant_type_t, c10::Float8_e4m3fnuz>);
  quant_type_t x;  // element 0
  quant_type_t y;  // element 1
  quant_type_t z;  // element 2
  quant_type_t w;  // element 3
};
sgl-kernel/tests/test_sampling_scaling_penalties.py
View file @
f1b68618
import
pytest
import
torch
from
sgl_kernel
import
sampling_scaling_penalties
def
test_sampling_scaling_penalties
():
batch_sizes
=
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
65
]
vocab_sizes
=
[
2048
,
4096
,
8192
,
16384
,
32768
,
32767
]
dtypes
=
[
torch
.
float32
,
torch
.
half
,
torch
.
bfloat16
]
@
pytest
.
mark
.
parametrize
(
"batch_size"
,
[
1
,
2
,
4
,
8
,
16
,
32
,
64
,
65
])
@
pytest
.
mark
.
parametrize
(
"vocab_size"
,
[
2048
,
4096
,
8192
,
16
384
,
32
768
,
32767
])
@
pytest
.
mark
.
parametrize
(
"dtype"
,
[
torch
.
half
,
torch
.
bfloat16
])
def
test_sampling_scaling_penalties
(
batch_size
,
vocab_size
,
dtype
):
device
=
torch
.
device
(
"cuda"
)
for
dtype
in
dtypes
:
rtol
=
1e-3
atol
=
1e-3
for
bs
in
batch_sizes
:
for
vocab_size
in
vocab_sizes
:
logits
=
torch
.
randn
(
bs
,
vocab_size
,
device
=
device
,
dtype
=
dtype
)
logits
=
torch
.
randn
(
batch_size
,
vocab_size
,
device
=
device
,
dtype
=
dtype
)
scaling_penalties
=
(
torch
.
rand
(
b
s
,
vocab_size
,
device
=
device
,
dtype
=
dtype
)
+
0.5
torch
.
rand
(
b
atch_size
,
vocab_size
,
device
=
device
,
dtype
=
dtype
)
+
0.5
)
ref_output
=
torch
.
where
(
...
...
@@ -30,7 +27,7 @@ def test_sampling_scaling_penalties():
ref_output
,
rtol
=
rtol
,
atol
=
atol
,
msg
=
f
"Failed for batch_size=
{
b
s
}
, vocab_size=
{
vocab_size
}
, dtype=
{
dtype
}
"
,
msg
=
f
"Failed for batch_size=
{
b
atch_size
}
, vocab_size=
{
vocab_size
}
, dtype=
{
dtype
}
"
,
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment