Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
cbb2f59c
"vscode:/vscode.git/clone" did not exist on "61ba33d5802363c09f0acab9e14b7f73abb414c4"
Unverified
Commit
cbb2f59c
authored
Jun 03, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jun 03, 2024
Browse files
[Kernel] Pass a device pointer into the quantize kernel for the scales (#5159)
parent
0ab278ca
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
16 additions
and
11 deletions
+16
-11
csrc/ops.h
csrc/ops.h
+2
-2
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
+9
-6
tests/kernels/test_int8_quant.py
tests/kernels/test_int8_quant.py
+3
-1
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
...d_tensors/schemes/compressed_tensors_w8a8_statictensor.py
+1
-1
No files found.
csrc/ops.h
View file @
cbb2f59c
...
@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
...
@@ -94,8 +94,8 @@ int cutlass_scaled_mm_dq(torch::Tensor& out, torch::Tensor const& a,
#endif
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
float
scale
);
torch
::
Tensor
const
&
scale
);
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
lookup_table
);
...
...
csrc/quantization/compressed_tensors/int8_quant_kernels.cu
View file @
cbb2f59c
...
@@ -28,9 +28,10 @@ namespace vllm {
...
@@ -28,9 +28,10 @@ namespace vllm {
template
<
typename
scalar_t
,
typename
scale_type
>
template
<
typename
scalar_t
,
typename
scale_type
>
__global__
void
static_scaled_int8_quant_kernel
(
__global__
void
static_scaled_int8_quant_kernel
(
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
int8_t
*
__restrict__
out
,
scale_type
scale
,
const
int
hidden_size
)
{
const
scale_type
*
scale
_ptr
,
const
int
hidden_size
)
{
const
int
tid
=
threadIdx
.
x
;
const
int
tid
=
threadIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
scale_type
scale
=
*
scale_ptr
;
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
for
(
int
i
=
tid
;
i
<
hidden_size
;
i
+=
blockDim
.
x
)
{
out
[
token_idx
*
hidden_size
+
i
]
=
out
[
token_idx
*
hidden_size
+
i
]
=
...
@@ -40,10 +41,12 @@ __global__ void static_scaled_int8_quant_kernel(
...
@@ -40,10 +41,12 @@ __global__ void static_scaled_int8_quant_kernel(
}
// namespace vllm
}
// namespace vllm
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
const
&
input
,
// [..., hidden_size]
float
scale
)
{
torch
::
Tensor
const
&
scale
)
{
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
input
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
out
.
is_contiguous
());
TORCH_CHECK
(
scale
.
numel
()
==
1
);
int
hidden_size
=
input
.
size
(
-
1
);
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
dim3
grid
(
num_tokens
);
dim3
grid
(
num_tokens
);
...
@@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
...
@@ -53,7 +56,7 @@ void static_scaled_int8_quant(torch::Tensor& out, // [..., hidden_size]
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
input
.
scalar_type
(),
"static_scaled_int8_quant_kernel"
,
[
&
]
{
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
vllm
::
static_scaled_int8_quant_kernel
<
scalar_t
,
float
>
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
input
.
data_ptr
<
scalar_t
>
(),
out
.
data_ptr
<
int8_t
>
(),
scale
,
out
.
data_ptr
<
int8_t
>
(),
hidden_size
);
scale
.
data_ptr
<
float
>
(),
hidden_size
);
});
});
}
}
tests/kernels/test_int8_quant.py
View file @
cbb2f59c
...
@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
...
@@ -26,6 +26,8 @@ def test_quant(num_tokens: int, hidden_size: int, dtype: torch.dtype,
torch
.
iinfo
(
torch
.
int8
).
min
,
torch
.
iinfo
(
torch
.
int8
).
min
,
torch
.
iinfo
(
torch
.
int8
).
max
).
to
(
torch
.
int8
)
torch
.
iinfo
(
torch
.
int8
).
max
).
to
(
torch
.
int8
)
out2
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
)
out2
=
torch
.
empty_like
(
x
,
dtype
=
torch
.
int8
)
ops
.
static_scaled_int8_quant
(
out2
,
x
,
scale
)
scale_argument
=
torch
.
tensor
([
scale
],
dtype
=
torch
.
float32
,
device
=
"cuda"
)
ops
.
static_scaled_int8_quant
(
out2
,
x
,
scale_argument
)
assert
torch
.
allclose
(
out1
,
out2
,
assert
torch
.
allclose
(
out1
,
out2
,
atol
=
1
)
# big atol to account for rounding errors
atol
=
1
)
# big atol to account for rounding errors
vllm/_custom_ops.py
View file @
cbb2f59c
...
@@ -265,7 +265,7 @@ def scaled_fp8_quant(
...
@@ -265,7 +265,7 @@ def scaled_fp8_quant(
# int8
# int8
def
static_scaled_int8_quant
(
input
:
torch
.
Tensor
,
def
static_scaled_int8_quant
(
input
:
torch
.
Tensor
,
scale
:
float
)
->
torch
.
Tensor
:
scale
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""
"""
Quantize the input tensor to int8 and return the quantized tensor.
Quantize the input tensor to int8 and return the quantized tensor.
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_statictensor.py
View file @
cbb2f59c
...
@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
...
@@ -97,7 +97,7 @@ class CompressedTensorsW8A8StaticTensor(CompressedTensorsScheme):
act_scale
=
layer
.
input_scale
act_scale
=
layer
.
input_scale
# Input quantize
# Input quantize
x_q
=
custom_ops
.
static_scaled_int8_quant
(
x
,
act_scale
[
0
].
item
()
)
x_q
=
custom_ops
.
static_scaled_int8_quant
(
x
,
act_scale
)
return
custom_ops
.
cutlass_scaled_mm_dq
(
x_q
,
weight
.
t
(),
act_scale
,
return
custom_ops
.
cutlass_scaled_mm_dq
(
x_q
,
weight
.
t
(),
act_scale
,
weight_scale
,
x
.
dtype
)
weight_scale
,
x
.
dtype
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment