Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
54271bb7
Unverified
Commit
54271bb7
authored
Apr 26, 2025
by
Charlie Fu
Committed by
GitHub
Apr 25, 2025
Browse files
[ROCm][Misc] Follow-ups for Skinny Gemms on ROCm. (#17011)
Signed-off-by:
charlifu
<
charlifu@amd.com
>
parent
9e96f56e
Changes
4
Show whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
18 additions
and
15 deletions
+18
-15
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
+4
-3
vllm/model_executor/layers/utils.py
vllm/model_executor/layers/utils.py
+4
-3
vllm/model_executor/layers/vocab_parallel_embedding.py
vllm/model_executor/layers/vocab_parallel_embedding.py
+2
-1
vllm/platforms/rocm.py
vllm/platforms/rocm.py
+8
-8
No files found.
vllm/model_executor/layers/quantization/utils/w8a8_utils.py
View file @
54271bb7
...
@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
...
@@ -155,8 +155,9 @@ def rocm_per_tensor_w8a8_scaled_mm(*, qinput: torch.Tensor,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
scale_b
:
torch
.
Tensor
,
bias
:
torch
.
Tensor
,
input_2d
:
torch
.
Tensor
,
input_2d
:
torch
.
Tensor
,
output_shape
:
List
)
->
torch
.
Tensor
:
output_shape
:
List
)
->
torch
.
Tensor
:
if
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
qinput
.
shape
[
from
vllm.platforms.rocm
import
on_mi250_mi300
0
]
==
1
and
qinput
.
shape
[
1
]
%
16
==
0
:
if
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
not
on_mi250_mi300
(
)
and
qinput
.
shape
[
0
]
==
1
and
qinput
.
shape
[
1
]
%
16
==
0
:
output
=
ops
.
wvSplitKQ
(
weight
.
t
(),
qinput
,
out_dtype
,
scale_a
,
scale_b
,
output
=
ops
.
wvSplitKQ
(
weight
.
t
(),
qinput
,
out_dtype
,
scale_a
,
scale_b
,
current_platform
.
get_cu_count
())
current_platform
.
get_cu_count
())
else
:
else
:
...
@@ -371,7 +372,7 @@ class Fp8LinearOp:
...
@@ -371,7 +372,7 @@ class Fp8LinearOp:
return
w8a8_scaled_mm_func
(
qinput
=
qinput
,
return
w8a8_scaled_mm_func
(
qinput
=
qinput
,
weight
=
weight
,
weight
=
weight
,
out_dtype
=
inp
ut
.
dtype
,
out_dtype
=
o
ut
_
dtype
,
scale_a
=
x_scale
,
scale_a
=
x_scale
,
scale_b
=
weight_scale
,
scale_b
=
weight_scale
,
bias
=
bias
,
bias
=
bias
,
...
...
vllm/model_executor/layers/utils.py
View file @
54271bb7
...
@@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
...
@@ -70,8 +70,9 @@ def apply_penalties(logits: torch.Tensor, prompt_tokens_tensor: torch.Tensor,
def
rocm_unquantized_gemm
(
x
:
torch
.
Tensor
,
def
rocm_unquantized_gemm
(
x
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
weight
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
bias
:
Optional
[
torch
.
Tensor
]
=
None
):
from
vllm.platforms.rocm
import
on_mi250_mi300
k
=
weight
.
shape
[
1
]
k
=
weight
.
shape
[
1
]
use_skinny
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
\
use_skinny
=
(
envs
.
VLLM_ROCM_USE_SKINNY_GEMM
and
on_mi250_mi300
()
and
\
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
\
x
.
dtype
in
[
torch
.
float16
,
torch
.
bfloat16
]
\
and
k
%
8
==
0
and
bias
is
None
)
and
k
%
8
==
0
and
bias
is
None
)
...
@@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
...
@@ -83,11 +84,11 @@ def rocm_unquantized_gemm(x: torch.Tensor,
m
=
weight
.
shape
[
0
]
m
=
weight
.
shape
[
0
]
cu_count
=
current_platform
.
get_cu_count
()
cu_count
=
current_platform
.
get_cu_count
()
if
m
>
8
and
n
<
4
:
if
m
>
8
and
0
<
n
<
4
:
out
=
ops
.
wvSplitK
(
weight
,
x_view
,
cu_count
)
out
=
ops
.
wvSplitK
(
weight
,
x_view
,
cu_count
)
return
out
.
view
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
return
out
.
view
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
elif
m
%
4
==
0
and
n
==
1
and
k
<=
8192
:
elif
m
%
4
==
0
and
n
==
1
and
k
<=
8192
:
out
=
ops
.
LLMM1
(
weight
,
x_view
,
out
,
4
)
out
=
ops
.
LLMM1
(
weight
,
x_view
,
4
)
return
out
.
view
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
return
out
.
view
(
*
x
.
shape
[:
-
1
],
weight
.
shape
[
0
])
return
torch
.
nn
.
functional
.
linear
(
x
,
weight
,
bias
)
return
torch
.
nn
.
functional
.
linear
(
x
,
weight
,
bias
)
...
...
vllm/model_executor/layers/vocab_parallel_embedding.py
View file @
54271bb7
...
@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -12,6 +12,7 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
QuantizationConfig
,
QuantizeMethodBase
,
method_has_implemented_embedding
)
from
vllm.model_executor.layers.utils
import
dispatch_unquantized_gemm
from
vllm.model_executor.parameter
import
BasevLLMParameter
from
vllm.model_executor.parameter
import
BasevLLMParameter
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.platforms
import
current_platform
from
vllm.platforms
import
current_platform
...
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
...
@@ -40,7 +41,7 @@ class UnquantizedEmbeddingMethod(QuantizeMethodBase):
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
return
F
.
linear
(
x
,
layer
.
weight
,
bias
)
return
dispatch_unquantized_gemm
()
(
x
,
layer
.
weight
,
bias
)
def
embedding
(
self
,
layer
:
torch
.
nn
.
Module
,
def
embedding
(
self
,
layer
:
torch
.
nn
.
Module
,
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
input_
:
torch
.
Tensor
)
->
torch
.
Tensor
:
...
...
vllm/platforms/rocm.py
View file @
54271bb7
...
@@ -98,21 +98,21 @@ def device_id_to_physical_device_id(device_id: int) -> int:
...
@@ -98,21 +98,21 @@ def device_id_to_physical_device_id(device_id: int) -> int:
return
device_id
return
device_id
def
on_mi250_mi300
()
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
@
cache
@
cache
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
def
use_rocm_custom_paged_attention
(
qtype
:
torch
.
dtype
,
head_size
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
block_size
:
int
,
gqa_ratio
:
int
,
max_seq_len
:
int
,
max_seq_len
:
int
,
sliding_window
:
int
)
->
bool
:
sliding_window
:
int
)
->
bool
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
# rocm custom page attention not support on gfx1*
ON_NAVI
=
"gfx1"
in
GPU_ARCH
ON_MI250_MI300
=
any
(
arch
in
GPU_ARCH
for
arch
in
[
"gfx90a"
,
"gfx942"
])
# rocm custom page attention not support on navi (gfx1*)
# custom paged attn always supported on V0. On V1, requires sliding window
# custom paged attn always supported on V0. On V1, requires sliding window
# disabled due to observed numerical discrepancy.
# disabled due to observed numerical discrepancy.
return
(
ON_MI250_MI300
and
not
ON_NAVI
return
(
on_mi250_mi300
()
and
(
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
and
(
not
envs
.
VLLM_USE_V1
or
sliding_window
==
0
or
sliding_window
==
(
-
1
,
-
1
))
or
sliding_window
==
(
-
1
,
-
1
))
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
qtype
==
torch
.
half
or
qtype
==
torch
.
bfloat16
)
and
(
head_size
==
64
or
head_size
==
128
)
and
(
head_size
==
64
or
head_size
==
128
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment