Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
c9f9d5b3
Unverified
Commit
c9f9d5b3
authored
Feb 14, 2025
by
Sage Moore
Committed by
GitHub
Feb 14, 2025
Browse files
[Bugfix][AMD] Update torch_bindings so that scaled_fp4_quant isn't built on ROCm (#13235)
parent
0c730268
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
12 additions
and
10 deletions
+12
-10
csrc/ops.h
csrc/ops.h
+4
-4
csrc/torch_bindings.cpp
csrc/torch_bindings.cpp
+7
-6
vllm/_custom_ops.py
vllm/_custom_ops.py
+1
-0
No files found.
csrc/ops.h
View file @
c9f9d5b3
...
...
@@ -177,6 +177,10 @@ void cutlass_scaled_sparse_mm(torch::Tensor& out, torch::Tensor const& a,
std
::
optional
<
torch
::
Tensor
>
const
&
bias
);
std
::
vector
<
torch
::
Tensor
>
cutlass_sparse_compress
(
torch
::
Tensor
const
&
a
);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
#endif
void
static_scaled_int8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
...
...
@@ -194,10 +198,6 @@ torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int64_t
bit
);
void
scaled_fp4_quant
(
torch
::
Tensor
&
output
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
&
output_scale
,
torch
::
Tensor
const
&
input_scale
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
input
,
torch
::
Tensor
const
&
scale
);
...
...
csrc/torch_bindings.cpp
View file @
c9f9d5b3
...
...
@@ -385,6 +385,13 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"bool silu_activation,"
"int pad_slot_id) -> ()"
);
ops
.
impl
(
"causal_conv1d_fwd"
,
torch
::
kCUDA
,
&
causal_conv1d_fwd
);
// Compute NVFP4 block quantized tensor.
ops
.
def
(
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
#endif
// Quantized GEMM for GPTQ.
...
...
@@ -421,12 +428,6 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
ops
.
impl
(
"dynamic_per_token_scaled_fp8_quant"
,
torch
::
kCUDA
,
&
dynamic_per_token_scaled_fp8_quant
);
// Compute NVFP4 block quantized tensor.
ops
.
def
(
"scaled_fp4_quant(Tensor! output, Tensor input,"
" Tensor! output_scale, Tensor input_scale) -> ()"
);
ops
.
impl
(
"scaled_fp4_quant"
,
torch
::
kCUDA
,
&
scaled_fp4_quant
);
// Compute int8 quantized tensor for given scaling factor.
ops
.
def
(
"static_scaled_int8_quant(Tensor! result, Tensor input, Tensor scale,"
...
...
vllm/_custom_ops.py
View file @
c9f9d5b3
...
...
@@ -774,6 +774,7 @@ def scaled_fp4_quant(
two values are packed into a uint8 and float8_e4m3 scaling factors
in the swizzled layout.
"""
assert
not
current_platform
.
is_rocm
()
assert
input
.
ndim
>=
1
,
(
f
'input.ndim needs to be >= 1, but got
{
input
.
ndim
}
.'
)
other_dims
=
1
if
input
.
ndim
==
1
else
-
1
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment