Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
a5314e86
Unverified
Commit
a5314e86
authored
Jul 19, 2024
by
Thomas Parnell
Committed by
GitHub
Jul 19, 2024
Browse files
[Model] RowParallelLinear: pass bias to quant_method.apply (#6327)
Signed-off-by: Thomas Parnell <tpa@zurich.ibm.com>
parent
a921e863
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
14 additions
and
9 deletions
+14
-9
tests/spec_decode/e2e/test_integration_dist_tp2.py
tests/spec_decode/e2e/test_integration_dist_tp2.py
+3
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+11
-9
No files found.
tests/spec_decode/e2e/test_integration_dist_tp2.py
View file @
a5314e86
...
@@ -83,6 +83,9 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
...
@@ -83,6 +83,9 @@ def test_target_model_tp_gt_1(baseline_llm_generator, test_llm_generator,
# cleaned up properly, and its server host thread leaks, causing the
# cleaned up properly, and its server host thread leaks, causing the
# second run of the test to fail with internal NCCL error.
# second run of the test to fail with internal NCCL error.
"use_async"
:
True
,
"use_async"
:
True
,
# precision
"dtype"
:
"float32"
,
}])
}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
"baseline_llm_kwargs"
,
[{}])
@
pytest
.
mark
.
parametrize
(
@
pytest
.
mark
.
parametrize
(
...
...
vllm/model_executor/layers/linear.py
View file @
a5314e86
...
@@ -715,6 +715,7 @@ class RowParallelLinear(LinearBase):
...
@@ -715,6 +715,7 @@ class RowParallelLinear(LinearBase):
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
...
@@ -770,18 +771,19 @@ class RowParallelLinear(LinearBase):
...
@@ -770,18 +771,19 @@ class RowParallelLinear(LinearBase):
# Matrix multiply.
# Matrix multiply.
assert
self
.
quant_method
is
not
None
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
# Only fuse bias add into GEMM for rank 0 (this ensures that
# bias will not get added more than once in TP>1 case)
bias_
=
None
if
(
self
.
tp_rank
>
0
or
self
.
skip_bias_add
)
else
self
.
bias
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
,
bias
=
bias_
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
output_
=
output_parallel
output
=
output_parallel
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
if
not
self
.
skip_bias_add
:
output
=
output_
+
self
.
bias
if
self
.
bias
is
not
None
else
output_
output_bias
=
None
else
:
output
=
output_
output_bias
=
self
.
bias
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
def
extra_repr
(
self
)
->
str
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment