Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
3f3b6b21
"tests/vscode:/vscode.git/clone" did not exist on "b95cc5014dc7b260e5c70ae33d1b30c54d11306d"
Unverified
Commit
3f3b6b21
authored
Jun 20, 2024
by
Tyler Michael Smith
Committed by
GitHub
Jun 20, 2024
Browse files
[Bugfix] Fix the CUDA version check for FP8 support in the CUTLASS kernels (#5715)
parent
a7dcc620
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
30 additions
and
13 deletions
+30
-13
csrc/ops.h
csrc/ops.h
+2
-0
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
+16
-0
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+6
-0
vllm/_custom_ops.py
vllm/_custom_ops.py
+4
-0
vllm/model_executor/layers/quantization/fp8.py
vllm/model_executor/layers/quantization/fp8.py
+2
-13
No files found.
csrc/ops.h
View file @
3f3b6b21
...
@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
...
@@ -92,6 +92,8 @@ torch::Tensor gptq_marlin_repack(torch::Tensor& b_q_weight, torch::Tensor& perm,
int64_t
size_k
,
int64_t
size_n
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int64_t
num_bits
);
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
);
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_entry.cu
View file @
3f3b6b21
...
@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
...
@@ -25,6 +25,22 @@ void cutlass_scaled_mm_sm90(torch::Tensor& c, torch::Tensor const& a,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
const
&
b_scales
);
#endif
#endif
bool
cutlass_scaled_mm_supports_fp8
(
int64_t
cuda_device_capability
)
{
// CUTLASS FP8 kernels need at least
// CUDA 12.0 on SM90 systems (Hopper)
// CUDA 12.4 on SM89 systems (Lovelace)
#if defined CUDA_VERSION
if
(
cuda_device_capability
>=
90
)
{
return
CUDA_VERSION
>=
12000
;
}
else
if
(
cuda_device_capability
>=
89
)
{
return
CUDA_VERSION
>=
12040
;
}
#endif
return
false
;
}
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
void
cutlass_scaled_mm
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
torch
::
Tensor
const
&
b_scales
)
{
...
...
csrc/torch_bindings.cpp
View file @
3f3b6b21
...
@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
...
@@ -144,6 +144,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
" Tensor b, Tensor a_scales,"
" Tensor b, Tensor a_scales,"
" Tensor b_scales) -> ()"
);
" Tensor b_scales) -> ()"
);
ops
.
impl
(
"cutlass_scaled_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm
);
ops
.
impl
(
"cutlass_scaled_mm"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm
);
// Check if cutlass scaled_mm is supported for CUDA devices of the given
// capability
ops
.
def
(
"cutlass_scaled_mm_supports_fp8"
,
&
cutlass_scaled_mm_supports_fp8
);
ops
.
impl
(
"cutlass_scaled_mm_supports_fp8"
,
torch
::
kCUDA
,
&
cutlass_scaled_mm_supports_fp8
);
#endif
#endif
// Quantized GEMM for GPTQ.
// Quantized GEMM for GPTQ.
...
...
vllm/_custom_ops.py
View file @
3f3b6b21
...
@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
...
@@ -216,6 +216,10 @@ def gptq_marlin_24_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
# cutlass
# cutlass
def
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
:
int
)
->
bool
:
return
torch
.
ops
.
_C
.
cutlass_scaled_mm_supports_fp8
(
cuda_device_capability
)
def
cutlass_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
def
cutlass_scaled_mm
(
a
:
torch
.
Tensor
,
b
:
torch
.
Tensor
,
scale_a
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
out_dtype
:
Type
[
torch
.
dtype
])
->
torch
.
Tensor
:
...
...
vllm/model_executor/layers/quantization/fp8.py
View file @
3f3b6b21
...
@@ -20,19 +20,8 @@ logger = init_logger(__name__)
...
@@ -20,19 +20,8 @@ logger = init_logger(__name__)
def
cutlass_fp8_supported
()
->
bool
:
def
cutlass_fp8_supported
()
->
bool
:
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
torch
.
cuda
.
get_device_capability
()
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
capability
=
capability
[
0
]
*
10
+
capability
[
1
]
major
,
minor
=
torch
.
version
.
cuda
.
split
(
"."
)
version
=
int
(
major
)
*
10
+
int
(
minor
)
return
ops
.
cutlass_scaled_mm_supports_fp8
(
capability
)
# CUTLASS FP8 kernels need at least
# CUDA 12.0 on SM90 systems (Hopper)
# CUDA 12.4 on SM89 systems (Lovelace)
gpu_is_supported
=
False
if
capability
>=
90
:
gpu_is_supported
=
version
>
120
elif
capability
>=
89
:
gpu_is_supported
=
version
>
124
return
gpu_is_supported
class
Fp8Config
(
QuantizationConfig
):
class
Fp8Config
(
QuantizationConfig
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment