norm / vllm / Commits / ba0bfd40

Unverified commit ba0bfd40
Authored Oct 02, 2023 by Zhuohan Li; committed by GitHub on Oct 02, 2023.
TP/quantization/weight loading refactor part 1 - Simplify parallel linear logic (#1181)
Parent: 84e4e37d

Changes: 41 files in this commit. This page shows 20 changed files with 169 additions and 77 deletions (+169, -77).
.github/workflows/pylint.yml                               +1  -1
.github/workflows/yapf.yml                                 +1  -1
.pylintrc                                                  +1  -1
format.sh                                                  +2  -3
tests/async_engine/api_server_async_engine.py              +1  -0
tests/async_engine/test_api_server.py                      +3  -0
tests/async_engine/test_async_llm_engine.py                +2  -2
tests/async_engine/test_request_tracker.py                 +11 -11
tests/conftest.py                                          +1  -0
tests/distributed/test_comm_ops.py                         +82 -0
tests/engine/test_detokenize.py                            +1  -0
tests/kernels/test_activation.py                           +6  -6
tests/kernels/test_cache.py                                +3  -3
tests/kernels/test_pos_encoding.py                         +1  -1
tests/samplers/test_sampler.py                             +2  -1
vllm/model_executor/layers/quantized_linear/__init__.py    +2  -2
vllm/model_executor/layers/quantized_linear/awq.py         +2  -2
vllm/model_executor/layers/sampler.py                      +3  -3
vllm/model_executor/models/aquila.py                       +22 -20
vllm/model_executor/models/baichuan.py                     +22 -20
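Most of the model-side hunks below follow one pattern: the parallel layers are imported from vllm.model_executor.parallel_utils.layers instead of parallel_utils.tensor_parallel, and the perform_initialization=False argument is dropped from every constructor call. As a quick orientation, here is a minimal sketch of the resulting call sites, paraphrased from the aquila.py and baichuan.py hunks below; the helper function is hypothetical, and constructing these layers still requires an initialized tensor-parallel process group, so treat this as an outline rather than a runnable script.

# Sketch of the post-refactor call sites (argument names taken from the diffs below).
from vllm.model_executor.parallel_utils.layers import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)


def build_llama_style_sublayers(vocab_size: int, hidden_size: int,
                                intermediate_size: int):
    # Hypothetical helper: groups the three constructor calls so the new
    # signatures are visible in one place. Note: no perform_initialization flag.
    embed_tokens = VocabParallelEmbedding(vocab_size, hidden_size)
    gate_up_proj = ColumnParallelLinear(hidden_size,
                                        2 * intermediate_size,
                                        bias=False,
                                        gather_output=False)
    down_proj = RowParallelLinear(intermediate_size,
                                  hidden_size,
                                  bias=False,
                                  input_is_parallel=True)
    return embed_tokens, gate_up_proj, down_proj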
.github/workflows/pylint.yml
@@ -28,4 +28,4 @@ jobs:
           pip install pylint==2.8.2
       - name: Analysing the code with pylint
         run: |
-          pylint vllm
+          pylint vllm tests
.github/workflows/yapf.yml
@@ -28,4 +28,4 @@ jobs:
           pip install toml==0.10.2
       - name: Running yapf
         run: |
-          yapf --diff --recursive vllm --exclude 'vllm/model_executor/parallel_utils/**'
+          yapf --diff --recursive vllm tests
.pylintrc
@@ -8,7 +8,7 @@
 [MASTER]

 # Files or directories to be skipped. They should be base names, not paths.
-ignore=docs,parallel_utils
+ignore=docs

 # Files or directories matching the regex patterns are skipped. The regex
 # matches against base names, not paths.
format.sh
@@ -44,7 +44,6 @@ YAPF_FLAGS=(
 YAPF_EXCLUDES=(
     '--exclude' 'build/**'
-    '--exclude' 'vllm/model_executor/parallel_utils/**'
 )

 # Format specified files

@@ -72,7 +71,7 @@ format_changed() {
 # Format all files
 format_all() {
-    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm
+    yapf --in-place "${YAPF_FLAGS[@]}" "${YAPF_EXCLUDES[@]}" vllm tests
 }

 ## This flag formats individual files. --files *must* be the first command line

@@ -96,7 +95,7 @@ echo 'vLLM yapf: Done'
 # Run Pylint
 echo 'vLLM Pylint:'
-pylint vllm
+pylint vllm tests
 if ! git diff --quiet &>/dev/null; then
     echo 'Reformatted files. Please review and stage the changes.'
tests/async_engine/api_server_async_engine.py
@@ -14,6 +14,7 @@ app = vllm.entrypoints.api_server.app
 class AsyncLLMEngineWithStats(AsyncLLMEngine):

+    # pylint: disable=redefined-outer-name
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._num_aborts = 0
tests/async_engine/test_api_server.py
@@ -24,6 +24,7 @@ def _query_server(prompt: str) -> dict:
 def api_server():
     script_path = Path(__file__).parent.joinpath(
         "api_server_async_engine.py").absolute()
+    # pylint: disable=consider-using-with
     uvicorn_process = subprocess.Popen([
         sys.executable, "-u",
         str(script_path), "--model", "facebook/opt-125m"

@@ -32,6 +33,7 @@ def api_server():
     uvicorn_process.terminate()


+# pylint: disable=redefined-outer-name, unused-argument
 def test_api_server(api_server):
     """
     Run the API server and test it.

@@ -47,6 +49,7 @@ def test_api_server(api_server):
     prompts = ["Hello world"] * 1
     result = None
     while not result:
+        # pylint: disable=bare-except
         try:
             for result in pool.map(_query_server, prompts):
                 break
tests/async_engine/test_async_llm_engine.py
@@ -32,12 +32,12 @@ class MockEngine:
         self.request_id = None

     def add_request(self, **kwargs):
         del kwargs  # Unused
         self.add_request_calls += 1
         return

     def abort_request(self, request_id):
         del request_id  # Unused
         self.abort_request_calls += 1
         return


 class MockAsyncLLMEngine(AsyncLLMEngine):
tests/async_engine/test_request_tracker.py
@@ -7,22 +7,22 @@ from vllm.outputs import RequestOutput
 class DummyEvent:

     def __init__(self):
-        self._flag = False
+        self.flag = False

     def set(self):
-        self._flag = True
+        self.flag = True

     def clear(self):
-        self._flag = False
+        self.flag = False


 def test_request_tracker():
     tracker = RequestTracker()
     tracker.new_requests_event = DummyEvent()
     stream_1 = tracker.add_request("1")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
     assert len(new) == 1
     assert new[0]["request_id"] == "1"
     assert not finished

@@ -30,9 +30,9 @@ def test_request_tracker():
     stream_2 = tracker.add_request("2")
     stream_3 = tracker.add_request("3")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
     assert len(new) == 2
     assert new[0]["request_id"] == "2"
     assert new[1]["request_id"] == "3"

@@ -43,7 +43,7 @@ def test_request_tracker():
     # request_ids must be unique
     with pytest.raises(KeyError):
         tracker.add_request("1")
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
     tracker.abort_request("1")
     new, finished = tracker.get_new_and_finished_requests()

@@ -54,7 +54,7 @@ def test_request_tracker():
     stream_4 = tracker.add_request("4")
     tracker.abort_request("4")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
     new, finished = tracker.get_new_and_finished_requests()
     assert len(finished) == 1
     assert "4" in finished

@@ -62,11 +62,11 @@ def test_request_tracker():
     assert stream_4.finished

     stream_5 = tracker.add_request("5")
-    assert tracker.new_requests_event._flag
+    assert tracker.new_requests_event.flag
     tracker.process_request_output(
         RequestOutput("2", "output", [], [], finished=True))
     new, finished = tracker.get_new_and_finished_requests()
-    assert not tracker.new_requests_event._flag
+    assert not tracker.new_requests_event.flag
     assert len(finished) == 1
     assert "2" in finished
     assert len(new) == 1
tests/conftest.py
@@ -8,6 +8,7 @@ from vllm import LLM, SamplingParams
 from vllm.transformers_utils.tokenizer import get_tokenizer

 _TEST_PROMPTS = [
+    # pylint: disable=line-too-long
     "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs.",
     "Briefly describe the major milestones in the development of artificial intelligence from 1950 to 2020.",
     "Compare and contrast artificial intelligence with human intelligence in terms of processing information.",
tests/distributed/test_comm_ops.py (new file, 0 → 100644)

"""Test the communication operators.

Run `pytest tests/distributed/test_comm_ops.py --forked`.
"""
from multiprocessing import Process

import pytest
import torch

from vllm.config import ParallelConfig
from vllm.engine.ray_utils import get_open_port
from vllm.model_executor.parallel_utils.communication_op import (
    tensor_model_parallel_all_reduce,
    tensor_model_parallel_all_gather,
)
from vllm.worker.worker import _init_distributed_environment


def init_test_distributed_environment(pipeline_parallel_size: int,
                                      tensor_parallel_size: int, rank: int,
                                      distributed_init_port: str):
    parallel_config = ParallelConfig(pipeline_parallel_size,
                                     tensor_parallel_size,
                                     worker_use_ray=True)
    distributed_init_method = f"tcp://localhost:{distributed_init_port}"
    torch.cuda.set_device(rank)
    _init_distributed_environment(parallel_config, rank,
                                  distributed_init_method)


def all_reduce_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_elements = 8
    all_tensors = [
        torch.arange(num_elements, dtype=torch.float32, device="cuda") *
        (r + 1) for r in range(tensor_parallel_size)
    ]
    expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)
    t = all_tensors[rank]
    t = tensor_model_parallel_all_reduce(t)
    assert torch.allclose(t, expected)


def all_gather_test_worker(tensor_parallel_size: int, rank: int,
                           distributed_init_port: str):
    init_test_distributed_environment(1, tensor_parallel_size, rank,
                                      distributed_init_port)
    num_dimensions = 3
    tensor_size = list(range(2, num_dimensions + 2))
    total_size = 1
    for s in tensor_size:
        total_size *= s
    for all_gather_dimension in range(num_dimensions):
        all_tensors = [
            torch.arange(total_size, dtype=torch.float32,
                         device="cuda").reshape(tensor_size) * (r + 1)
            for r in range(tensor_parallel_size)
        ]
        expected = torch.cat(all_tensors, dim=all_gather_dimension)
        t = all_tensors[rank]
        t = tensor_model_parallel_all_gather(t, all_gather_dimension)
        assert torch.allclose(t, expected)


@pytest.mark.skipif(torch.cuda.device_count() < 2,
                    reason="Need at least 2 GPUs to run the test.")
@pytest.mark.parametrize("tensor_parallel_size", [2])
@pytest.mark.parametrize("test_target",
                         [all_reduce_test_worker, all_gather_test_worker])
def test_multi_process_tensor_parallel(tensor_parallel_size, test_target):
    distributed_init_port = get_open_port()
    processes = []
    for rank in range(tensor_parallel_size):
        p = Process(target=test_target,
                    args=(tensor_parallel_size, rank, distributed_init_port))
        p.start()
        processes.append(p)
    for p in processes:
        p.join()
    assert all(p.exitcode == 0 for p in processes)
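For context, the two helpers exercised by this new test come from vllm.model_executor.parallel_utils.communication_op. Their expected behaviour, as encoded in the assertions above, can be sketched with plain torch.distributed collectives; this is a sketch only, assuming an already-initialized process group, and naive_all_reduce / naive_all_gather are hypothetical stand-ins rather than vLLM's implementation.

# Sketch: the semantics the test asserts, expressed with raw torch.distributed.
import torch
import torch.distributed as dist


def naive_all_reduce(t: torch.Tensor) -> torch.Tensor:
    # Sum the tensor across all ranks, matching
    # `expected = torch.sum(torch.stack(all_tensors, dim=0), dim=0)` above.
    dist.all_reduce(t, op=dist.ReduceOp.SUM)
    return t


def naive_all_gather(t: torch.Tensor, dim: int, world_size: int) -> torch.Tensor:
    # Gather one tensor per rank and concatenate along `dim`, matching
    # `expected = torch.cat(all_tensors, dim=all_gather_dimension)` above.
    gathered = [torch.empty_like(t) for _ in range(world_size)]
    dist.all_gather(gathered, t)
    return torch.cat(gathered, dim=dim)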
tests/engine/test_detokenize.py
@@ -5,6 +5,7 @@ from transformers import AutoTokenizer
 from vllm.transformers_utils.tokenizer import detokenize_incrementally

 TRUTH = [
+    # pylint: disable=line-too-long
     "Hello here, this is a simple test",
     "vLLM is a high-throughput and memory-efficient inference and serving engine for LLMs. It is designed to be used in production environments, where inference and serving",
     "我很感谢你的热情"
tests/kernels/test_activation.py
@@ -29,8 +29,8 @@ def test_silu_and_mul(
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, 2 * d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.silu_and_mul(out, x)
     ref_out = ref_silu_and_mul(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

@@ -49,8 +49,8 @@ def test_gelu_new(
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.gelu_new(out, x)
     ref_out = get_activation("gelu_new")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)

@@ -68,8 +68,8 @@ def test_gelu_fast(
 ) -> None:
     torch.random.manual_seed(seed)
     torch.cuda.manual_seed(seed)
-    x = torch.randn(num_tokens, d, dtype=dtype, device='cuda')
-    out = torch.empty(num_tokens, d, dtype=dtype, device='cuda')
+    x = torch.randn(num_tokens, d, dtype=dtype, device="cuda")
+    out = torch.empty(num_tokens, d, dtype=dtype, device="cuda")
     activation_ops.gelu_fast(out, x)
     ref_out = get_activation("gelu_fast")(x)
     assert torch.allclose(out, ref_out, atol=1e-5, rtol=1e-5)
tests/kernels/test_cache.py
@@ -106,14 +106,14 @@ def test_reshape_and_cache(
     # Create a random slot mapping.
     num_slots = block_size * num_blocks
     slot_mapping = random.sample(range(num_slots), num_tokens)
-    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device='cuda')
+    slot_mapping = torch.tensor(slot_mapping, dtype=torch.int, device="cuda")

     qkv = torch.randn(num_tokens,
                       3,
                       num_heads,
                       head_size,
                       dtype=dtype,
-                      device='cuda')
+                      device="cuda")
     _, key, value = qkv.unbind(dim=1)

     # Create the KV caches.

@@ -132,7 +132,7 @@ def test_reshape_and_cache(
     # Run the reference implementation.
     reshaped_key = key.reshape(num_tokens, *key_cache[0, :, :, 0, :].shape)
-    block_indicies = torch.div(slot_mapping, block_size, rounding_mode='floor')
+    block_indicies = torch.div(slot_mapping, block_size, rounding_mode="floor")
     block_indicies = block_indicies.cpu().tolist()
     block_offsets = slot_mapping % block_size
     block_offsets = block_offsets.cpu().tolist()
tests/kernels/test_pos_encoding.py
@@ -140,7 +140,7 @@ def test_rotary_embedding(
     cos = freqs.cos()
     sin = freqs.sin()
     cos_sin_cache = torch.cat((cos, sin), dim=-1)
-    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device='cuda')
+    cos_sin_cache = cos_sin_cache.to(dtype=dtype, device="cuda")

     # Run the kernel. The kernel is in-place, so we need to clone the inputs.
     out_query = query.clone()
tests/samplers/test_sampler.py
+# pylint: disable=protected-access
 import pytest
 import random
 from typing import Tuple

@@ -108,7 +109,7 @@ def test_sampler_all_random(seed: int):
 def test_sampler_all_beam(seed: int):
     set_random_seed(seed)
     batch_size = random.randint(1, 256)
-    input_tensor, fake_logits, sampler, worker = _prepare_test(batch_size)
+    input_tensor, _, sampler, worker = _prepare_test(batch_size)
     seq_group_metadata_list = []
     for i in range(batch_size):
vllm/model_executor/layers/quantized_linear/__init__.py
 from vllm.model_executor.layers.quantized_linear.awq import (
     AWQColumnParallelLinear, AWQRowParallelLinear)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    ColumnParallelLinear, RowParallelLinear)

 _QUANTIZED_LINEAR_REGISTRY = {
     "awq": (AWQColumnParallelLinear, AWQRowParallelLinear),
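The registry above keys the (column-parallel, row-parallel) AWQ classes by quantization method name. A self-contained illustration of the lookup shape, using placeholder classes instead of the real ones (illustrative only; the actual selection code is not part of this diff):

# Illustrative only: mirrors the registry pattern with placeholder classes.
class FakeAWQColumnParallelLinear:
    pass


class FakeAWQRowParallelLinear:
    pass


_FAKE_QUANTIZED_LINEAR_REGISTRY = {
    "awq": (FakeAWQColumnParallelLinear, FakeAWQRowParallelLinear),
}


def get_quantized_linear_classes(method: str):
    # Return the (column-parallel, row-parallel) classes for a quant method.
    return _FAKE_QUANTIZED_LINEAR_REGISTRY[method]


column_cls, row_cls = get_quantized_linear_classes("awq")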
vllm/model_executor/layers/quantized_linear/awq.py
@@ -4,8 +4,8 @@ import torch
 from torch.nn.parameter import Parameter

 from vllm import quantization_ops
-from vllm.model_executor.parallel_utils.tensor_parallel.layers import (
-    ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (
+    ColumnParallelLinear, RowParallelLinear)


 class AWQColumnParallelLinear(ColumnParallelLinear):
vllm/model_executor/layers/sampler.py
@@ -5,8 +5,8 @@ import torch
 import torch.nn as nn

 from vllm.model_executor.input_metadata import InputMetadata
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    gather_from_tensor_model_parallel_region)
+from vllm.model_executor.parallel_utils.communication_op import (
+    tensor_model_parallel_all_gather)
 from vllm.sampling_params import SamplingParams, SamplingType
 from vllm.sequence import SamplerOutput, SequenceData, SequenceOutputs

@@ -92,7 +92,7 @@ def _get_logits(hidden_states: torch.Tensor, embedding: torch.Tensor,
     logits = torch.matmul(hidden_states, embedding.t())
     if embedding_bias is not None:
         logits += embedding_bias
-    logits = gather_from_tensor_model_parallel_region(logits)
+    logits = tensor_model_parallel_all_gather(logits)
     # Remove paddings in vocab (if any).
     logits = logits[:, :vocab_size]
     return logits
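The rename in _get_logits reflects what the gather is doing: with a column-parallel lm_head, each tensor-parallel rank only computes logits for its own vocabulary shard, and the full logits row is the concatenation of those shards. A small CPU-only illustration of that shape arithmetic (toy sizes, no distributed setup; not vLLM code):

# Toy illustration of why the logits need an all-gather along the vocab dim.
import torch

tp_size, vocab_per_rank, hidden = 2, 4, 8
hidden_states = torch.randn(1, hidden)
# Each rank holds a vocab shard of the lm_head weight.
weight_shards = [torch.randn(vocab_per_rank, hidden) for _ in range(tp_size)]
# Each rank's local matmul yields only its slice of the logits...
partial_logits = [hidden_states @ w.t() for w in weight_shards]
# ...and the all-gather reconstructs the full vocabulary dimension.
full_logits = torch.cat(partial_logits, dim=-1)
assert full_logits.shape == (1, tp_size * vocab_per_rank)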
vllm/model_executor/models/aquila.py
@@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
     load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.aquila import AquilaConfig

@@ -56,16 +57,18 @@ class AquilaMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")

@@ -130,14 +133,12 @@ class AquilaAttention(nn.Module):
             self.head_dim,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         self.attn = PagedAttentionWithRoPE(self.num_heads,

@@ -230,7 +231,7 @@ class AquilaModel(nn.Module):
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size, config.hidden_size, perform_initialization=False)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             AquilaDecoderLayer(config) for _ in range(config.num_hidden_layers)
         ])

@@ -270,11 +271,12 @@ class AquilaForCausalLM(nn.Module):
         self.config = config
         self.model = AquilaModel(config)
         vocab_size = ((config.vocab_size + 63) // 64) * 64
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(
vllm/model_executor/models/baichuan.py
@@ -39,8 +39,9 @@ from vllm.model_executor.weight_utils import (
     load_padded_tensor_parallel_vocab, load_tensor_parallel_weights)
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
-from vllm.model_executor.parallel_utils.tensor_parallel import (
-    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear)
+from vllm.model_executor.parallel_utils.layers import (VocabParallelEmbedding,
+                                                       ColumnParallelLinear,
+                                                       RowParallelLinear)
 from vllm.sequence import SamplerOutput
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig

@@ -81,16 +82,18 @@ class BaiChuanMLP(nn.Module):
         hidden_act: str,
     ):
         super().__init__()
-        self.gate_up_proj = ColumnParallelLinear(hidden_size,
-                                                 2 * intermediate_size,
-                                                 bias=False,
-                                                 gather_output=False,
-                                                 perform_initialization=False)
-        self.down_proj = RowParallelLinear(intermediate_size,
-                                           hidden_size,
-                                           bias=False,
-                                           input_is_parallel=True,
-                                           perform_initialization=False)
+        self.gate_up_proj = ColumnParallelLinear(
+            hidden_size,
+            2 * intermediate_size,
+            bias=False,
+            gather_output=False,
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=False,
+            input_is_parallel=True,
+        )
         if hidden_act != "silu":
             raise ValueError(f"Unsupported activation: {hidden_act}. "
                              "Only silu is supported for now.")

@@ -133,14 +136,12 @@ class BaiChuanAttention(nn.Module):
             3 * hidden_size,
             bias=False,
             gather_output=False,
-            perform_initialization=False,
         )
         self.o_proj = RowParallelLinear(
             self.total_num_heads * self.head_dim,
             hidden_size,
             bias=False,
             input_is_parallel=True,
-            perform_initialization=False,
         )
         # Create the alibi slopes and slice them.
         if self.postion_embedding == "ALIBI":

@@ -249,7 +250,7 @@ class BaiChuanModel(nn.Module):
-        self.embed_tokens = VocabParallelEmbedding(
-            config.vocab_size, config.hidden_size, perform_initialization=False)
+        self.embed_tokens = VocabParallelEmbedding(config.vocab_size,
+                                                   config.hidden_size)
         self.layers = nn.ModuleList([
             BaiChuanDecoderLayer(config, position_embedding)
             for _ in range(config.num_hidden_layers)

@@ -288,11 +289,12 @@ class BaiChuanBaseForCausalLM(nn.Module):
         super().__init__()
         self.config = config
         self.model = BaiChuanModel(config, position_embedding)
-        self.lm_head = ColumnParallelLinear(config.hidden_size,
-                                            config.vocab_size,
-                                            bias=False,
-                                            gather_output=False,
-                                            perform_initialization=False)
+        self.lm_head = ColumnParallelLinear(
+            config.hidden_size,
+            config.vocab_size,
+            bias=False,
+            gather_output=False,
+        )
         self.sampler = Sampler(config.vocab_size)

     def forward(