Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
56ebbba3
"vscode:/vscode.git/clone" did not exist on "b1ddea7fd94c52a7be76cec721d32d438681af83"
Commit
56ebbba3
authored
Oct 21, 2025
by
zhuwenwen
Browse files
update linear of RowParallelLinear and UnquantizedLinearMethod apply
parent
ee346e93
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
68 additions
and
12 deletions
+68
-12
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+68
-12
No files found.
vllm/model_executor/layers/linear.py
View file @
56ebbba3
...
@@ -250,7 +250,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -250,7 +250,9 @@ class UnquantizedLinearMethod(LinearMethodBase):
def
apply
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
if
gemm_bank_conf
(
layer
.
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
if
gemm_bank_conf
(
layer
.
weight
.
shape
[
1
]
-
32
)
and
os
.
environ
[
'GEMM_PAD'
]
==
'1'
:
layer
.
weight
=
layer
.
weight
[:,:
-
32
]
layer
.
weight
=
layer
.
weight
[:,:
-
32
]
...
@@ -265,8 +267,42 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -265,8 +267,42 @@ class UnquantizedLinearMethod(LinearMethodBase):
else
:
else
:
if
envs
.
VLLM_USE_NN
and
x
.
shape
[
-
1
]
==
layer
.
weight
.
shape
[
0
]:
if
envs
.
VLLM_USE_NN
and
x
.
shape
[
-
1
]
==
layer
.
weight
.
shape
[
0
]:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
.
t
(),
bias
)
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
.
t
(),
bias
)
else
:
weight
=
layer
.
weight
if
residual
is
not
None
:
assert
output
is
None
or
output
is
residual
if
get_tensor_model_parallel_world_size
(
)
>
1
and
get_tensor_model_parallel_rank
()
!=
0
:
beta
=
0.0
else
:
beta
=
1.0
# optimize cuda memory usage
if
x
.
dim
()
==
2
:
torch
.
addmm
(
residual
,
x
,
weight
.
t
(),
beta
=
beta
,
out
=
residual
)
elif
x
.
dim
()
>=
3
:
hx
=
x
.
size
(
-
1
)
hr
=
residual
.
size
(
-
1
)
torch
.
addmm
(
residual
.
view
(
-
1
,
hr
),
x
.
view
(
-
1
,
hx
),
weight
.
t
(),
beta
=
beta
,
out
=
residual
.
view
(
-
1
,
hr
))
else
:
raise
AssertionError
(
"unrecognized tensor dimensions: {}"
.
format
(
x
.
dim
()))
if
bias
is
not
None
:
residual
+=
bias
return
residual
else
:
if
output
is
not
None
:
if
bias
is
not
None
:
# always separate bias add when output is provided
torch
.
matmul
(
x
,
weight
.
t
(),
out
=
output
)
output
.
add_
(
bias
)
return
output
return
torch
.
matmul
(
x
,
weight
.
t
(),
out
=
output
)
else
:
else
:
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
return
dispatch_unquantized_gemm
()(
x
,
layer
.
weight
,
bias
)
# return dispatch_unquantized_gemm()(x, layer.weight, bias)
class
UnquantizedMoELinearMethod
(
LinearMethodBase
):
class
UnquantizedMoELinearMethod
(
LinearMethodBase
):
...
@@ -633,7 +669,8 @@ class ColumnParallelLinear(LinearBase):
...
@@ -633,7 +669,8 @@ class ColumnParallelLinear(LinearBase):
self
,
input_
,
self
,
input_
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
rms_weight
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
residual
:
Optional
[
torch
.
Tensor
]
=
None
,
update_hd
:
Optional
[
bool
]
=
True
update_hd
:
Optional
[
bool
]
=
True
,
output
:
Optional
[
torch
.
Tensor
]
=
None
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
if
envs
.
USE_FUSED_RMS_QUANT
and
rms_weight
is
not
None
:
input_quant_args
=
None
input_quant_args
=
None
...
@@ -663,7 +700,7 @@ class ColumnParallelLinear(LinearBase):
...
@@ -663,7 +700,7 @@ class ColumnParallelLinear(LinearBase):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
# Matrix multiply.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
,
output
=
output
)
if
self
.
gather_output
:
if
self
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
@@ -1703,6 +1740,9 @@ class RowParallelLinear(LinearBase):
...
@@ -1703,6 +1740,9 @@ class RowParallelLinear(LinearBase):
def
forward
(
def
forward
(
self
,
input_
,
self
,
input_
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
,
use_fused_silu_mul_quant
:
Optional
[
bool
]
=
False
,
residual
=
None
,
output
=
None
,
disable_allreduce
=
False
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
)
->
Union
[
torch
.
Tensor
,
tuple
[
torch
.
Tensor
,
Optional
[
Parameter
]]]:
if
self
.
input_is_parallel
:
if
self
.
input_is_parallel
:
input_parallel
=
input_
input_parallel
=
input_
...
@@ -1712,7 +1752,14 @@ class RowParallelLinear(LinearBase):
...
@@ -1712,7 +1752,14 @@ class RowParallelLinear(LinearBase):
input_
,
num_partitions
=
self
.
tp_size
)
input_
,
num_partitions
=
self
.
tp_size
)
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# only add residual to the first rank
if
residual
is
not
None
and
self
.
tp_size
>
1
and
get_tensor_model_parallel_rank
(
)
!=
0
:
residual
*=
0
# Matrix multiply.
# Matrix multiply.
if
output
is
not
None
:
assert
disable_allreduce
or
not
self
.
reduce_results
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
# Only fuse bias add into GEMM for rank 0 (this ensures that
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
# bias will not get added more than once in TP>1 case)
...
@@ -1728,19 +1775,28 @@ class RowParallelLinear(LinearBase):
...
@@ -1728,19 +1775,28 @@ class RowParallelLinear(LinearBase):
else
:
else
:
output_parallel
=
self
.
quant_method
.
apply
(
self
,
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
input_parallel
,
bias
=
bias_
)
residual
=
residual
,
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output
=
output
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
and
not
disable_allreduce
:
if
envs
.
VLLM_ENABLE_TBO
:
if
envs
.
VLLM_ENABLE_TBO
:
output
=
self
.
tbo_all_reduce
(
output_parallel
)
output
_
=
self
.
tbo_all_reduce
(
output_parallel
)
else
:
else
:
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
output
=
output_parallel
output
_
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
else
:
output
=
output_
output_bias
=
self
.
bias
# output_bias = self.bias if self.skip_bias_add else None
# if not self.return_bias:
# return output
if
not
self
.
return_bias
:
return
output
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment