norm / vllm · Commits

Commit 1b0bd0fe (Unverified)
Add Falcon support (new) (#592)
Authored Aug 02, 2023 by Zhuohan Li; committed by GitHub on Aug 02, 2023
Parent: 20044cab

Showing 16 changed files with 680 additions and 122 deletions (+680, -122)
README.md                                                         +1   -0
csrc/pos_encoding_kernels.cu                                       +29  -13
docs/source/models/supported_models.rst                            +3   -0
examples/llm_engine_example.py                                     +2   -1
vllm/config.py                                                     +7   -2
vllm/model_executor/layers/attention.py                            +24  -12
vllm/model_executor/model_loader.py                                +2   -0
vllm/model_executor/models/__init__.py                             +2   -0
vllm/model_executor/models/falcon.py                               +496 -0
vllm/model_executor/parallel_utils/parallel_state.py               +1   -72
vllm/model_executor/parallel_utils/tensor_parallel/__init__.py     +2   -1
vllm/model_executor/parallel_utils/tensor_parallel/layers.py       +16  -15
vllm/transformers_utils/config.py                                  +2   -0
vllm/transformers_utils/configs/__init__.py                        +5   -0
vllm/transformers_utils/configs/falcon.py                          +87  -0
vllm/worker/worker.py                                              +1   -6
README.md

@@ -44,6 +44,7 @@ vLLM seamlessly supports many Huggingface models, including the following architectures:
 - Baichuan-7B (`baichuan-inc/Baichuan-7B`)
 - BLOOM (`bigscience/bloom`, `bigscience/bloomz`, etc.)
+- Falcon (`tiiuae/falcon-7b`, `tiiuae/falcon-40b`, `tiiuae/falcon-rw-7b`, etc.)
 - GPT-2 (`gpt2`, `gpt2-xl`, etc.)
 - GPT BigCode (`bigcode/starcoder`, `bigcode/gpt_bigcode-santacoder`, etc.)
 - GPT-J (`EleutherAI/gpt-j-6b`, `nomic-ai/gpt4all-j`, etc.)
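With this entry added, Falcon checkpoints load through the same offline API as the other models in the list. A minimal usage sketch (the model name comes from the README entry above; the prompt and sampling values are arbitrary examples, not part of the commit):

from vllm import LLM, SamplingParams

# The new FalconForCausalLM implementation is selected automatically from the
# architecture string in the checkpoint's config.
llm = LLM(model="tiiuae/falcon-7b")

outputs = llm.generate(
    ["The capital of France is"],
    SamplingParams(temperature=0.8, max_tokens=32),
)
for out in outputs:
    print(out.outputs[0].text)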
csrc/pos_encoding_kernels.cu

@@ -10,7 +10,8 @@ __global__ void rotary_embedding_neox_kernel(
   scalar_t* __restrict__ key,                  // [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int stride,
+  const int query_stride,
+  const int key_stride,
   const int num_heads,
   const int num_kv_heads,
   const int head_size) {

@@ -23,14 +24,14 @@ __global__ void rotary_embedding_neox_kernel(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * stride + head_idx * head_size;
+    const int token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
     const int x_index = rot_offset;
     const int y_index = embed_dim + rot_offset;

-    const int out_x = token_idx * stride + head_idx * head_size + x_index;
-    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+    const int out_x = token_idx * query_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * query_stride + head_idx * head_size + y_index;

     const scalar_t cos = __ldg(cache_ptr + x_index);
     const scalar_t sin = __ldg(cache_ptr + y_index);

@@ -39,13 +40,27 @@ __global__ void rotary_embedding_neox_kernel(
     const scalar_t q_y = query[token_head + y_index];
     query[out_x] = q_x * cos - q_y * sin;
     query[out_y] = q_y * cos + q_x * sin;
+  }
+
+  const int nk = num_kv_heads * embed_dim;
+  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
+    const int head_idx = i / embed_dim;
+    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int rot_offset = i % embed_dim;
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const int out_x = token_idx * key_stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * key_stride + head_idx * head_size + y_index;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);

-    if (head_idx < num_kv_heads) {
     const scalar_t k_x = key[token_head + x_index];
     const scalar_t k_y = key[token_head + y_index];
     key[out_x] = k_x * cos - k_y * sin;
     key[out_y] = k_y * cos + k_x * sin;
-    }
   }
 }

@@ -62,8 +77,8 @@ void rotary_embedding_neox(
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(1) / head_size;
   int num_kv_heads = key.size(1) / head_size;
-  int stride = query.stride(0);
-  TORCH_CHECK(stride == key.stride(0));
+  int query_stride = query.stride(0);
+  int key_stride = key.stride(0);
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));

@@ -80,7 +95,8 @@ void rotary_embedding_neox(
         key.data_ptr<scalar_t>(),
         cos_sin_cache.data_ptr<scalar_t>(),
         rot_dim,
-        stride,
+        query_stride,
+        key_stride,
         num_heads,
         num_kv_heads,
         head_size);
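For reference, the rotation this kernel applies is the NeoX-style (half-split) rotary embedding, now looped separately over the query heads and the possibly smaller set of KV heads, each with its own row stride. A PyTorch sketch of the same math (illustrative reference only, not the vLLM code path):

import torch

def rotary_neox(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # x:   [num_tokens, num_heads, head_size] (query or key)
    # cos: [num_tokens, rot_dim // 2], sin: [num_tokens, rot_dim // 2]
    # Only the first rot_dim elements of each head are rotated.
    rot_dim = 2 * cos.shape[-1]
    x_rot, x_pass = x[..., :rot_dim], x[..., rot_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)          # the "x_index" and "y_index" halves
    c, s = cos[:, None, :], sin[:, None, :]  # broadcast over heads
    rotated = torch.cat([x1 * c - x2 * s,    # out_x = x*cos - y*sin
                         x2 * c + x1 * s],   # out_y = y*cos + x*sin
                        dim=-1)
    return torch.cat([rotated, x_pass], dim=-1)

# The kernel now runs this rotation once over num_heads query heads and once
# over num_kv_heads key heads, so query and key may have different widths.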
docs/source/models/supported_models.rst

@@ -20,6 +20,9 @@ Alongside each architecture, we include some popular models that use it.
   * - :code:`BloomForCausalLM`
     - BLOOM, BLOOMZ, BLOOMChat
     - :code:`bigscience/bloom`, :code:`bigscience/bloomz`, etc.
+  * - :code:`FalconForCausalLM`
+    - Falcon
+    - :code:`tiiuae/falcon-7b`, :code:`tiiuae/falcon-40b`, :code:`tiiuae/falcon-rw-7b`, etc.
   * - :code:`GPT2LMHeadModel`
     - GPT-2
     - :code:`gpt2`, :code:`gpt2-xl`, etc.
examples/llm_engine_example.py

@@ -10,7 +10,8 @@ def main(args: argparse.Namespace):
     # Test the following prompts.
     test_prompts = [
-        ("A robot may not injure a human being", SamplingParams()),
+        ("A robot may not injure a human being",
+         SamplingParams(temperature=0.0)),
         ("To be or not to be,",
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
        ("What is the meaning of life?",
vllm/config.py

@@ -94,8 +94,13 @@ class ModelConfig:
         return self.hf_config.hidden_size // self.hf_config.num_attention_heads

     def get_num_heads(self, parallel_config: "ParallelConfig") -> int:
-        # For GPTBigCode:
-        if getattr(self.hf_config, "multi_query", False):
+        # For GPTBigCode & Falcon:
+        # Note: for falcon, when new_decoder_architecture is True, the
+        # multi_query flag is ignored and we use n_head_kv for the number of
+        # KV heads.
+        if (getattr(self.hf_config, "multi_query", False) and
+            (self.hf_config.model_type == "falcon" and
+             not getattr(self.hf_config, "new_decoder_architecture", False))):
             # Multi-query attention, only one KV head.
             return 1
         # For Falcon:
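To make the rule concrete: Falcon checkpoints describe their attention layout with a few flags, and the number of KV heads follows from them. A hypothetical helper that mirrors the rule spelled out in the comment above (not the vLLM implementation; the falcon-40b value of n_head_kv=8 is taken from its published config):

def num_kv_heads(cfg) -> int:
    # Illustrative mirror of the rule above, not vLLM code.
    if getattr(cfg, "new_decoder_architecture", False):
        return cfg.n_head_kv          # e.g. falcon-40b: 8 KV heads (GQA)
    if getattr(cfg, "multi_query", False):
        return 1                      # e.g. falcon-7b / falcon-rw: MQA
    return cfg.n_head                 # classic multi-head attention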
vllm/model_executor/layers/attention.py

@@ -314,14 +314,13 @@ class PagedAttentionWithRoPE(PagedAttention):
 class PagedAttentionWithALiBi(PagedAttention):
     """PagedAttention with ALiBi attention bias."""

     def __init__(
         self,
         num_heads: int,
         head_size: int,
         scale: float,
         slopes: List[float],
-    ) -> None:
-        super().__init__(num_heads, head_size, scale)
+        num_kv_heads: Optional[int] = None) -> None:
+        super().__init__(num_heads, head_size, scale, num_kv_heads)
         assert len(slopes) == num_heads

         slopes = torch.tensor(slopes, dtype=torch.float32)

@@ -334,6 +333,11 @@ class PagedAttentionWithALiBi(PagedAttention):
         # Generates ALiBi mask for each prompt.
         for prompt_len in input_metadata.prompt_lens:
             bias = torch.arange(prompt_len)
+            # Note(zhuohan): HF uses
+            #     `bias = bias[None, :].repeat(prompt_len, 1)`
+            # here. We find that both biases give the same results, but
+            # the bias below more accurately follows the original ALiBi
+            # paper.
             bias = bias[None, :] - bias[:, None]
             bias = bias.to(self.alibi_slopes.device)

@@ -363,10 +367,17 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_prompt_tokens, num_heads, head_size]
             query: shape = [num_prompt_tokens, num_heads, head_size]
-            key: shape = [num_prompt_tokens, num_heads, head_size]
-            value: shape = [num_prompt_tokens, num_heads, head_size]
+            key: shape = [num_prompt_tokens, num_kv_heads, head_size]
+            value: shape = [num_prompt_tokens, num_kv_heads, head_size]
             input_metadata: metadata for paged attention.
         """
+        if self.num_kv_heads != self.num_heads:
+            # Project the key and value tensors to the desired number of heads.
+            key = torch.repeat_interleave(key, self.num_queries_per_kv, dim=1)
+            value = torch.repeat_interleave(value, self.num_queries_per_kv,
+                                            dim=1)
+
         # FIXME(woosuk): Because xformers does not support dynamic sequence
         # lengths with custom attention bias, we process each prompt one by
         # one. This is inefficient, especially when we have many short prompts.

@@ -400,9 +411,10 @@ class PagedAttentionWithALiBi(PagedAttention):
         Args:
             output: shape = [num_generation_tokens, num_heads, head_size]
             query: shape = [num_generation_tokens, num_heads, head_size]
-            key_cache: shape = [num_blocks, num_heads, head_size/x,
+            key_cache: shape = [num_blocks, num_kv_heads, head_size/x,
                 block_size, x]
-            value_cache: shape = [num_blocks, num_heads, head_size, block_size]
+            value_cache: shape = [num_blocks, num_kv_heads, head_size,
+                block_size]
             input_metadata: metadata for paged attention.
         """
         block_size = value_cache.shape[3]
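The repeat_interleave projection added above is what lets multi-query and grouped-query prompts reuse the standard attention path: each KV head is copied once per query head that shares it. A standalone shape check (illustrative sizes, not from the diff):

import torch

num_tokens, num_heads, num_kv_heads, head_size = 5, 8, 2, 64
num_queries_per_kv = num_heads // num_kv_heads      # 4 query heads per KV head

key = torch.randn(num_tokens, num_kv_heads, head_size)
# Duplicate each KV head so the key tensor lines up head-for-head with the query.
expanded = torch.repeat_interleave(key, num_queries_per_kv, dim=1)
assert expanded.shape == (num_tokens, num_heads, head_size)
# Heads 0-3 of `expanded` are copies of KV head 0; heads 4-7 of KV head 1.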
vllm/model_executor/model_loader.py

@@ -14,6 +14,7 @@ _MODEL_REGISTRY = {
     "BaiChuanForCausalLM": BaiChuanForCausalLM,  # baichuan-7b
     "BaichuanForCausalLM": BaichuanForCausalLM,  # baichuan-13b
     "BloomForCausalLM": BloomForCausalLM,
+    "FalconForCausalLM": FalconForCausalLM,
     "GPT2LMHeadModel": GPT2LMHeadModel,
     "GPTBigCodeForCausalLM": GPTBigCodeForCausalLM,
     "GPTJForCausalLM": GPTJForCausalLM,

@@ -22,6 +23,7 @@ _MODEL_REGISTRY = {
     "LLaMAForCausalLM": LlamaForCausalLM,  # For decapoda-research/llama-*
     "MPTForCausalLM": MPTForCausalLM,
     "OPTForCausalLM": OPTForCausalLM,
+    "RWForCausalLM": FalconForCausalLM,
 }
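Both the new FalconForCausalLM architecture string and the legacy RWForCausalLM string (used in the config.json of the original RefinedWeb checkpoints) now resolve to the same implementation. Dispatch is a plain dictionary lookup on the architecture name from the HF config; a simplified sketch of that lookup (hypothetical helper, not the vLLM loader function):

def resolve_model_class(hf_config):
    # hf_config.architectures comes from the checkpoint's config.json,
    # e.g. ["RWForCausalLM"] for the original Falcon releases or
    # ["FalconForCausalLM"] for newer exports; both hit the same class.
    for arch in getattr(hf_config, "architectures", []):
        if arch in _MODEL_REGISTRY:
            return _MODEL_REGISTRY[arch]
    raise ValueError(f"Unsupported architectures: {hf_config.architectures}")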
vllm/model_executor/models/__init__.py

 from vllm.model_executor.models.baichuan import BaiChuanForCausalLM, BaichuanForCausalLM
 from vllm.model_executor.models.bloom import BloomForCausalLM
+from vllm.model_executor.models.falcon import FalconForCausalLM
 from vllm.model_executor.models.gpt2 import GPT2LMHeadModel
 from vllm.model_executor.models.gpt_bigcode import GPTBigCodeForCausalLM
 from vllm.model_executor.models.gpt_j import GPTJForCausalLM

@@ -12,6 +13,7 @@ __all__ = [
     "BaiChuanForCausalLM",
     "BaichuanForCausalLM",
     "BloomForCausalLM",
+    "FalconForCausalLM",
     "GPT2LMHeadModel",
     "GPTBigCodeForCausalLM",
     "GPTJForCausalLM",
vllm/model_executor/models/falcon.py (new file, 0 → 100644)

# coding=utf-8
# Adapted from
# https://github.com/huggingface/transformers/blob/a5cc30d72ae2dc19af534e4b35c986cc28db1275/src/transformers/models/falcon/modeling_falcon.py
# Copyright 2023 The vLLM team.
# Copyright 2023 the Falcon authors and HuggingFace Inc. team. All rights
# reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch Falcon model."""
import math
from typing import Dict, List, Optional, Tuple, Union

import torch
from torch import nn
from torch.nn import LayerNorm
from transformers import FalconConfig as HF_FalconConfig

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import (PagedAttention,
                                                  PagedAttentionWithALiBi,
                                                  PagedAttentionWithRoPE)
from vllm.model_executor.layers.sampler import Sampler
from vllm.model_executor.weight_utils import (hf_model_weights_iterator,
                                              load_tensor_parallel_weights)
from vllm.model_executor.parallel_utils.parallel_state import (
    get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size)
from vllm.model_executor.parallel_utils.tensor_parallel import (
    VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear,
    reduce_from_tensor_model_parallel_region)
from vllm.sequence import SequenceOutputs
from vllm.transformers_utils.configs import RWConfig

KVCache = Tuple[torch.Tensor, torch.Tensor]
FalconConfig = Union[HF_FalconConfig, RWConfig]
# NOTE(Hesslow): Unfortunately we did not fuse matmul and bias during
# training, this means that there's one additional quantization to bfloat16
# between the operations. In order not to degrade the quality of our HF-port,
# we keep these characteristics in the final model.
class FalconLinear(nn.Linear):

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden_states = x @ self.weight.T
        if self.bias is None:
            return hidden_states
        return hidden_states + self.bias


def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
    closest_power_of_2 = 2**math.floor(math.log2(total_num_heads))
    base = torch.tensor(2**(-(2**-(math.log2(closest_power_of_2) - 3))),
                        dtype=torch.float32)
    powers = torch.arange(1, 1 + closest_power_of_2, dtype=torch.int32)
    slopes = torch.pow(base, powers)

    if closest_power_of_2 != total_num_heads:
        extra_base = torch.tensor(
            2**(-(2**-(math.log2(2 * closest_power_of_2) - 3))),
            dtype=torch.float32)
        num_remaining_heads = min(closest_power_of_2,
                                  total_num_heads - closest_power_of_2)
        extra_powers = torch.arange(1,
                                    1 + 2 * num_remaining_heads,
                                    2,
                                    dtype=torch.int32)
        slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)],
                           dim=0)

    return slopes
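The slope schedule follows the original ALiBi recipe: a geometric sequence starting at 2**(-8/n) for the nearest power-of-two head count n, with any remaining heads taking every other slope from the next power of two. A quick illustrative check of what the function returns:

# For 8 heads: base = 2**(-8/8) = 0.5, so the slopes are 0.5, 0.25, ..., 0.5**8.
print(_get_alibi_slopes(8).tolist())
# [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]

# For 12 heads: the 8 slopes above, plus 4 "extra" slopes taken from every
# other entry of the 16-head schedule (base 2**(-8/16) ~= 0.7071).
print(_get_alibi_slopes(12).shape)  # torch.Size([12])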
class FalconAttention(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()

        self.hidden_size = config.hidden_size
        tp_size = get_tensor_model_parallel_world_size()

        self.total_num_heads = config.num_attention_heads
        assert self.total_num_heads % tp_size == 0
        self.num_heads = self.total_num_heads // tp_size
        self.head_dim = self.hidden_size // self.total_num_heads
        assert self.head_dim * self.total_num_heads == self.hidden_size

        self.new_decoder_architecture = config.new_decoder_architecture
        self.multi_query = config.multi_query

        if self.new_decoder_architecture:
            self.total_num_kv_heads = config.num_kv_heads
            assert self.total_num_heads % tp_size == 0
            self.num_kv_heads = self.total_num_kv_heads // tp_size
            self.query_key_value = ColumnParallelLinear(
                self.hidden_size,
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_dim,
                bias=config.bias,
                gather_output=False,
                perform_initialization=False,
                skip_bias_add=True,
            )
        elif self.multi_query:
            self.total_num_kv_heads = 1
            self.num_kv_heads = 1
            self.query = ColumnParallelLinear(
                self.hidden_size,
                self.total_num_heads * self.head_dim,
                bias=config.bias,
                gather_output=False,
                perform_initialization=False,
                skip_bias_add=True,
            )
            self.key_value = FalconLinear(self.hidden_size,
                                          2 * self.head_dim,
                                          bias=config.bias)
        else:
            self.total_num_kv_heads = self.total_num_heads
            self.num_kv_heads = self.num_heads
            self.query_key_value = ColumnParallelLinear(
                self.hidden_size,
                (self.total_num_heads + 2 * self.total_num_kv_heads) *
                self.head_dim,
                bias=config.bias,
                gather_output=False,
                perform_initialization=False,
                skip_bias_add=True,
            )

        self.q_size = self.num_heads * self.head_dim
        self.kv_size = self.num_kv_heads * self.head_dim

        # Layer-wise attention scaling
        self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.bias,
            input_is_parallel=True,
            perform_initialization=False,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results)

        self.use_rotary = config.rotary
        self.use_alibi = config.alibi
        assert not (self.use_rotary and self.use_alibi), (
            "Rotary and alibi are mutually exclusive.")

        if self.use_rotary:
            # TODO(zhuohan): Pass in correct `max_position``
            self.attn = PagedAttentionWithRoPE(self.num_heads,
                                               self.head_dim,
                                               self.inv_norm_factor,
                                               rotary_dim=self.head_dim,
                                               num_kv_heads=self.num_kv_heads)
        elif self.use_alibi:
            tp_rank = get_tensor_model_parallel_rank()
            head_start = tp_rank * self.num_heads
            head_end = (tp_rank + 1) * self.num_heads
            alibi_slopes = (_get_alibi_slopes(self.total_num_heads) *
                            self.inv_norm_factor)
            alibi_slopes = alibi_slopes[head_start:head_end].tolist()
            self.attn = PagedAttentionWithALiBi(self.num_heads,
                                                self.head_dim,
                                                self.inv_norm_factor,
                                                alibi_slopes,
                                                num_kv_heads=self.num_kv_heads)
        else:
            self.attn = PagedAttention(self.num_heads,
                                       self.head_dim,
                                       scale=self.inv_norm_factor,
                                       num_kv_heads=self.num_kv_heads)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ) -> torch.Tensor:
        if not self.new_decoder_architecture and self.multi_query:
            q, bias = self.query(hidden_states)
            if bias is not None:
                q += bias
            kv = self.key_value(hidden_states)
            k, v = kv.split([self.kv_size, self.kv_size], dim=-1)
        else:
            qkv, bias = self.query_key_value(hidden_states)
            if bias is not None:
                qkv += bias
            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
                                dim=-1)
        k_cache, v_cache = kv_cache
        if self.use_rotary:
            attn_output = self.attn(positions, q, k, v, k_cache, v_cache,
                                    input_metadata, cache_event)
        else:
            attn_output = self.attn(q, k, v, k_cache, v_cache, input_metadata,
                                    cache_event)
        attn_output, bias = self.dense(attn_output)
        return attn_output, bias
class FalconMLP(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size

        self.dense_h_to_4h = ColumnParallelLinear(hidden_size,
                                                  4 * hidden_size,
                                                  bias=config.bias,
                                                  gather_output=False,
                                                  perform_initialization=False,
                                                  skip_bias_add=True)
        self.act = nn.GELU()
        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)
        self.dense_4h_to_h = RowParallelLinear(
            4 * hidden_size,
            hidden_size,
            bias=config.bias,
            input_is_parallel=True,
            perform_initialization=False,
            skip_bias_add=True,
            reduce_results=self.reduce_row_parallel_results)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # NOTE(zhuohan): Following huggingface, we do not fuse bias add here.
        x, bias = self.dense_h_to_4h(x)
        if bias is not None:
            x += bias
        x = self.act(x)
        x, bias = self.dense_4h_to_h(x)
        return x, bias
class FalconDecoderLayer(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.self_attention = FalconAttention(config)
        self.mlp = FalconMLP(config)
        self.config = config

        if config.new_decoder_architecture:
            # The layer norm before self-attention
            self.ln_attn = LayerNorm(hidden_size,
                                     eps=config.layer_norm_epsilon)
            # The layer norm before the MLP
            self.ln_mlp = LayerNorm(hidden_size,
                                    eps=config.layer_norm_epsilon)
        else:
            self.input_layernorm = LayerNorm(hidden_size,
                                             eps=config.layer_norm_epsilon)
            if not config.parallel_attn:
                self.post_attention_layernorm = LayerNorm(
                    hidden_size, eps=config.layer_norm_epsilon)

        self.reduce_row_parallel_results = not (config.new_decoder_architecture
                                                or config.parallel_attn)

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        kv_cache: KVCache,
        input_metadata: InputMetadata,
        cache_event: Optional[torch.cuda.Event],
    ):
        residual = hidden_states

        if self.config.new_decoder_architecture:
            attention_layernorm_out = self.ln_attn(hidden_states)
            mlp_layernorm_out = self.ln_mlp(hidden_states)
        else:
            attention_layernorm_out = self.input_layernorm(hidden_states)

        # Self attention.
        attention_output, attention_bias = self.self_attention(
            positions=positions,
            hidden_states=attention_layernorm_out,
            kv_cache=kv_cache,
            input_metadata=input_metadata,
            cache_event=cache_event,
        )
        if self.reduce_row_parallel_results and attention_bias is not None:
            attention_output += attention_bias

        if not self.config.new_decoder_architecture:
            if self.config.parallel_attn:
                mlp_layernorm_out = attention_layernorm_out
            else:
                residual += attention_output
                mlp_layernorm_out = self.post_attention_layernorm(residual)

        # MLP.
        mlp_output, mlp_bias = self.mlp(mlp_layernorm_out)
        if self.reduce_row_parallel_results and mlp_bias is not None:
            mlp_output += mlp_bias

        if not self.reduce_row_parallel_results:
            # When MLP and Attention layers are parallel, we can use
            # only one all-reduce operator to reduce the results from
            # both MLP and Attention layers.
            mlp_output += attention_output
            mlp_output = reduce_from_tensor_model_parallel_region(mlp_output)
            if attention_bias is not None:
                mlp_output += attention_bias
            if mlp_bias is not None:
                mlp_output += mlp_bias

        output = mlp_output + residual
        return output
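The single-all-reduce comment above relies on all-reduce being linear: reducing the sum of the attention and MLP partial outputs once gives the same result as reducing each separately and then summing. A minimal sketch of the identity, assuming a torch.distributed process group is already initialized (illustrative, not the vLLM call path):

import torch
import torch.distributed as dist

def fused_reduce(attn_partial: torch.Tensor,
                 mlp_partial: torch.Tensor) -> torch.Tensor:
    # all_reduce(a + b) == all_reduce(a) + all_reduce(b), so the two
    # row-parallel projections can share a single collective when the
    # attention and MLP branches run in parallel.
    fused = attn_partial + mlp_partial
    dist.all_reduce(fused)  # in-place sum across tensor-parallel ranks
    return fused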
class FalconModel(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.use_alibi = config.alibi

        # Embedding + LN Embedding
        self.word_embeddings = VocabParallelEmbedding(
            config.vocab_size, self.embed_dim, perform_initialization=False)

        # Transformer blocks
        self.h = nn.ModuleList([
            FalconDecoderLayer(config)
            for _ in range(config.num_hidden_layers)
        ])

        # Final Layer Norm
        self.ln_f = LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> torch.Tensor:
        hidden_states = self.word_embeddings(input_ids)
        for i in range(len(self.h)):
            if cache_events is None:
                cache_event = None
            else:
                cache_event = cache_events[i]
            layer = self.h[i]
            hidden_states = layer(
                positions,
                hidden_states,
                kv_caches[i],
                input_metadata,
                cache_event,
            )
        hidden_states = self.ln_f(hidden_states)
        return hidden_states
class FalconForCausalLM(nn.Module):

    def __init__(self, config: FalconConfig):
        super().__init__()
        self.config = config
        self.transformer = FalconModel(config)
        self.lm_head = ColumnParallelLinear(config.hidden_size,
                                            config.vocab_size,
                                            bias=False,
                                            gather_output=False,
                                            perform_initialization=False)
        self.sampler = Sampler(config.vocab_size)

    def forward(
        self,
        input_ids: torch.LongTensor,
        positions: torch.Tensor,
        kv_caches: List[KVCache],
        input_metadata: InputMetadata,
        cache_events: Optional[List[torch.cuda.Event]],
    ) -> Dict[int, SequenceOutputs]:
        hidden_states = self.transformer(
            input_ids,
            positions,
            kv_caches,
            input_metadata,
            cache_events,
        )
        next_tokens = self.sampler(self.lm_head.weight, hidden_states,
                                   input_metadata)
        return next_tokens

    _column_parallel_weights = [
        "word_embeddings.weight", "lm_head.weight", "dense_h_to_4h.weight",
        "dense_h_to_4h.bias"
    ]
    _row_parallel_weights = ["dense.weight", "dense_4h_to_h.weight"]

    def load_weights(self,
                     model_name_or_path: str,
                     cache_dir: Optional[str] = None,
                     use_np_cache: bool = False):
        tp_size = (get_tensor_model_parallel_world_size())
        tp_rank = get_tensor_model_parallel_rank()

        hidden_size = self.config.hidden_size
        total_num_heads = self.config.num_attention_heads
        num_heads = total_num_heads // tp_size
        head_size = hidden_size // total_num_heads
        head_start = tp_rank * num_heads
        head_end = (tp_rank + 1) * num_heads
        if self.config.new_decoder_architecture:
            total_num_kv_heads = self.config.num_kv_heads
            num_kv_heads = total_num_kv_heads // tp_size
            separated_q_kv = False
            kv_head_start = tp_rank * num_kv_heads
            kv_head_end = (tp_rank + 1) * num_kv_heads
        elif self.config.multi_query:
            total_num_kv_heads = 1
            num_kv_heads = 1
            separated_q_kv = True
            kv_head_start = 0
            kv_head_end = 1
        else:
            total_num_kv_heads = total_num_heads
            num_kv_heads = total_num_kv_heads // tp_size
            separated_q_kv = False
            kv_head_start = tp_rank * num_kv_heads
            kv_head_end = (tp_rank + 1) * num_kv_heads
        num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads

        state_dict = self.state_dict()
        for name, loaded_weight in hf_model_weights_iterator(
                model_name_or_path, cache_dir, use_np_cache):
            if "query_key_value" in name:
                loaded_weight_size = loaded_weight.size()
                loaded_weight = loaded_weight.view(
                    total_num_kv_heads, num_query_heads_per_kv_head + 2,
                    head_size, *loaded_weight_size[1:])

                wq = loaded_weight[:, :-2].reshape(-1,
                                                   *loaded_weight_size[1:])
                wk = loaded_weight[:, [-2]].reshape(-1,
                                                    *loaded_weight_size[1:])
                wv = loaded_weight[:, [-1]].reshape(-1,
                                                    *loaded_weight_size[1:])

                wq = wq[head_size * head_start:head_size * head_end]
                wk = wk[head_size * kv_head_start:head_size * kv_head_end]
                wv = wv[head_size * kv_head_start:head_size * kv_head_end]

                if separated_q_kv:
                    loaded_weight_q = wq
                    loaded_weight_kv = torch.cat([wk, wv], dim=0)
                    q_weight_name = name.replace("query_key_value", "query")
                    kv_weight_name = name.replace("query_key_value",
                                                  "key_value")
                    load_tensor_parallel_weights(state_dict[q_weight_name],
                                                 loaded_weight_q,
                                                 q_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tp_rank)
                    load_tensor_parallel_weights(state_dict[kv_weight_name],
                                                 loaded_weight_kv,
                                                 kv_weight_name,
                                                 self._column_parallel_weights,
                                                 self._row_parallel_weights,
                                                 tp_rank)
                    continue
                else:
                    loaded_weight = torch.cat([wq, wk, wv], dim=0)

            param = state_dict[name]
            load_tensor_parallel_weights(param, loaded_weight, name,
                                         self._column_parallel_weights,
                                         self._row_parallel_weights, tp_rank)
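The view(total_num_kv_heads, num_query_heads_per_kv_head + 2, head_size, ...) reshape reflects how Falcon stores its fused projection: each KV group holds its query heads first, followed by one key head and one value head. A shape-only sketch of that split with made-up sizes (not a real checkpoint):

import torch

hidden_size, total_num_heads, total_num_kv_heads = 64, 8, 2
head_size = hidden_size // total_num_heads              # 8
q_per_kv = total_num_heads // total_num_kv_heads        # 4

# Fused QKV weight: (num_heads + 2 * num_kv_heads) * head_size output rows.
fused = torch.randn((total_num_heads + 2 * total_num_kv_heads) * head_size,
                    hidden_size)

grouped = fused.view(total_num_kv_heads, q_per_kv + 2, head_size, hidden_size)
wq = grouped[:, :-2].reshape(-1, hidden_size)   # 4 query heads per group
wk = grouped[:, [-2]].reshape(-1, hidden_size)  # 1 key head per group
wv = grouped[:, [-1]].reshape(-1, hidden_size)  # 1 value head per group
assert wq.shape[0] == total_num_heads * head_size
assert wk.shape[0] == total_num_kv_heads * head_size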
vllm/model_executor/parallel_utils/parallel_state.py

@@ -44,7 +44,6 @@ _PIPELINE_GLOBAL_RANKS = None
 # rank when broadcasting weights from src to all other data parallel ranks
 _DATA_PARALLEL_GLOBAL_RANKS = None
 
-_ALL_REDUCE_LAUNCHER: Optional['GraphAllReduce'] = None
 
 def initialize_model_parallel(
     tensor_model_parallel_size: int = 1,

@@ -196,20 +195,6 @@ def initialize_model_parallel(
     if rank in ranks:
         _POSITION_EMBEDDING_GLOBAL_RANKS = position_embedding_ranks
 
-def initialize_all_reduce_launcher(
-    max_num_tokens: int,
-    hidden_size: int,
-    dtype: torch.dtype,
-    disable_graph: bool = False,
-) -> None:
-    global _ALL_REDUCE_LAUNCHER
-    _ALL_REDUCE_LAUNCHER = GraphAllReduce(
-        max_num_tokens=max_num_tokens,
-        hidden_size=hidden_size,
-        dtype=dtype,
-        disable_graph=disable_graph,
-    )
 
 def model_parallel_is_initialized():
     """Check if model and data parallel groups are initialized."""
     if _TENSOR_MODEL_PARALLEL_GROUP is None or \

@@ -458,6 +443,7 @@ def get_pipeline_model_parallel_last_rank():
     last_rank_local = get_pipeline_model_parallel_world_size() - 1
     return _PIPELINE_GLOBAL_RANKS[last_rank_local]
 
 def get_pipeline_model_parallel_next_rank():
     """Return the global rank that follows the caller in the pipeline"""
     assert _PIPELINE_GLOBAL_RANKS is not None, \

@@ -485,10 +471,6 @@ def get_data_parallel_rank():
     """Return my rank for the data parallel group."""
     return torch.distributed.get_rank(group=get_data_parallel_group())
 
-def get_all_reduce_launcher() -> 'GraphAllReduce':
-    assert _ALL_REDUCE_LAUNCHER is not None, \
-        'all reduce launcher is not initialized'
-    return _ALL_REDUCE_LAUNCHER
 
 def destroy_model_parallel():
     """Set the groups to none."""
     global _MODEL_PARALLEL_GROUP

@@ -515,56 +497,3 @@ def destroy_model_parallel():
     _MPU_TENSOR_MODEL_PARALLEL_RANK = None
     global _MPU_PIPELINE_MODEL_PARALLEL_RANK
     _MPU_PIPELINE_MODEL_PARALLEL_RANK = None
-
-
-class GraphAllReduce:
-
-    def __init__(
-        self,
-        max_num_tokens: int,
-        hidden_size: int,
-        dtype: torch.dtype,
-        disable_graph: bool = False,
-    ) -> None:
-        self.max_num_tokens = max_num_tokens
-        self.hidden_size = hidden_size
-        self.disable_graph = disable_graph
-
-        tp_world_size = get_tensor_model_parallel_world_size()
-        if tp_world_size == 1:
-            return
-        self.group = get_tensor_model_parallel_group()
-        self.buffer = torch.empty(
-            size=(max_num_tokens, hidden_size),
-            dtype=dtype,
-            device='cuda',
-        )
-
-        # Build graphs for different number of tokens.
-        if not self.disable_graph:
-            self.graphs = {}
-            for num_tokens in range(8, max_num_tokens + 1, 8):
-                self.graphs[num_tokens] = self._build_graph(num_tokens)
-
-    def _build_graph(self, num_tokens: int) -> torch.cuda.CUDAGraph:
-        # Warm up.
-        torch.distributed.all_reduce(self.buffer[:num_tokens],
-                                     group=self.group)
-        torch.cuda.synchronize()
-
-        # Build graph.
-        graph = torch.cuda.CUDAGraph()
-        with torch.cuda.graph(graph):
-            torch.distributed.all_reduce(self.buffer[:num_tokens],
-                                         group=self.group)
-        torch.cuda.synchronize()
-        return graph
-
-    def launch(self, x: torch.Tensor) -> torch.Tensor:
-        # NOTE: x must be a slice of self.buffer.
-        num_tokens = x.shape[0]
-        if self.disable_graph:
-            torch.distributed.all_reduce(x, group=self.group)
-        else:
-            self.graphs[num_tokens].replay()
-        return x
vllm/model_executor/parallel_utils/tensor_parallel/__init__.py

@@ -12,6 +12,7 @@ from .mappings import (
     copy_to_tensor_model_parallel_region,
     gather_from_tensor_model_parallel_region,
     gather_from_sequence_parallel_region,
+    reduce_from_tensor_model_parallel_region,
     scatter_to_tensor_model_parallel_region,
     scatter_to_sequence_parallel_region,
 )

@@ -38,7 +39,7 @@ __all__ = [
     "copy_to_tensor_model_parallel_region",
     "gather_from_tensor_model_parallel_region",
     "gather_from_sequence_parallel_region",
-    # "reduce_from_tensor_model_parallel_region",
+    "reduce_from_tensor_model_parallel_region",
     "scatter_to_tensor_model_parallel_region",
     "scatter_to_sequence_parallel_region",
     # random.py
vllm/model_executor/parallel_utils/tensor_parallel/layers.py

@@ -14,7 +14,6 @@ from torch.nn.parameter import Parameter
 from vllm.model_executor.parallel_utils.parallel_state import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
-    get_all_reduce_launcher,
 )
 from .mappings import (
     copy_to_tensor_model_parallel_region,

@@ -248,8 +247,8 @@ class ColumnParallelLinear(torch.nn.Module):
         self.output_size = output_size
         self.gather_output = gather_output
         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.output_size_per_partition = divide(output_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.output_size_per_partition = divide(output_size, self.world_size)
         self.skip_bias_add = skip_bias_add
         if params_dtype is None:

@@ -350,6 +349,7 @@ class RowParallelLinear(torch.nn.Module):
         params_dtype:
         use_cpu_initialization:
         perform_initialization:
+        reduce_results:
     """

     def __init__(self, input_size, output_size, *,

@@ -360,6 +360,7 @@ class RowParallelLinear(torch.nn.Module):
                  params_dtype=None,
                  use_cpu_initialization=False,
                  perform_initialization=True,
+                 reduce_results=True,
                  ):
         super(RowParallelLinear, self).__init__()

@@ -367,14 +368,19 @@ class RowParallelLinear(torch.nn.Module):
         self.input_size = input_size
         self.output_size = output_size
         self.input_is_parallel = input_is_parallel
+        self.reduce_results = reduce_results
         if params_dtype is None:
             params_dtype = torch.get_default_dtype()

         # Divide the weight matrix along the last dimension.
-        world_size = get_tensor_model_parallel_world_size()
-        self.input_size_per_partition = divide(input_size, world_size)
+        self.world_size = get_tensor_model_parallel_world_size()
+        self.input_size_per_partition = divide(input_size, self.world_size)
         self.skip_bias_add = skip_bias_add

+        if not reduce_results and (bias and not skip_bias_add):
+            raise ValueError("When not reduce the results, adding bias to the "
+                             "results can lead to incorrect results")
+
         # Parameters.
         # Note: torch.nn.functional.linear performs XA^T + b and as a result
         # we allocate the transpose.

@@ -427,17 +433,12 @@ class RowParallelLinear(torch.nn.Module):
             input_parallel = input_
         else:
             input_parallel = scatter_to_tensor_model_parallel_region(input_)
-        if get_tensor_model_parallel_world_size() == 1:
-            # Matrix multiply.
-            output_ = F.linear(input_parallel, self.weight)
-        else:
-            # Matrix multiply.
-            all_reduce_launcher = get_all_reduce_launcher()
-            num_tokens = input_parallel.shape[0]
-            output_buffer = all_reduce_launcher.buffer[:num_tokens]
-            torch.matmul(input_parallel, self.weight_t, out=output_buffer)
-            # All-reduce across all the partitions.
-            output_ = all_reduce_launcher.launch(output_buffer)
+        # Matrix multiply.
+        output_parallel = F.linear(input_parallel, self.weight)
+        if self.reduce_results and self.world_size > 1:
+            output_ = reduce_from_tensor_model_parallel_region(output_parallel)
+        else:
+            output_ = output_parallel

         if not self.skip_bias_add:
             output = output_ + self.bias if self.bias is not None else output_
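The new reduce_results flag lets a caller, such as the Falcon decoder layer in this commit, take ownership of the all-reduce. A hedged usage sketch, assuming tensor-parallel groups are already initialized (as in a vLLM worker) and that hidden_states is the sharded input on this rank; sizes are arbitrary:

# Defer the all-reduce of a row-parallel projection to the caller.
proj = RowParallelLinear(4096, 4096,
                         bias=False,
                         input_is_parallel=True,
                         perform_initialization=False,
                         skip_bias_add=True,
                         reduce_results=False)   # new flag: skip the reduce

partial, _ = proj(hidden_states)                 # per-rank partial result
# ...combine with other partial results, then reduce exactly once:
output = reduce_from_tensor_model_parallel_region(partial)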
vllm/transformers_utils/config.py

@@ -5,6 +5,8 @@ from vllm.transformers_utils.configs import *  # pylint: disable=wildcard-import
 _CONFIG_REGISTRY = {
     "mpt": MPTConfig,
     "baichuan": BaiChuanConfig,
+    "RefinedWeb": RWConfig,  # For tiiuae/falcon-40b(-instruct)
+    "RefinedWebModel": RWConfig,  # For tiiuae/falcon-7b(-instruct)
 }
vllm/transformers_utils/configs/__init__.py

 from vllm.transformers_utils.configs.mpt import MPTConfig
 from vllm.transformers_utils.configs.baichuan import BaiChuanConfig
+# RWConfig is for the original tiiuae/falcon-40b(-instruct) and
+# tiiuae/falcon-7b(-instruct) models. Newer Falcon models will use the
+# `FalconConfig` class from the official HuggingFace transformers library.
+from vllm.transformers_utils.configs.falcon import RWConfig

 __all__ = [
     "MPTConfig",
     "BaiChuanConfig",
+    "RWConfig",
 ]
vllm/transformers_utils/configs/falcon.py (new file, 0 → 100644)

# Adapted from
# https://huggingface.co/tiiuae/falcon-7b/blob/main/configuration_RW.py
# Copyright 2023 The vLLM team.
# Copyright 2022 the Big Science Workshop and HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Falcon configuration"""
from transformers.configuration_utils import PretrainedConfig


class RWConfig(PretrainedConfig):
    model_type = "falcon"
    keys_to_ignore_at_inference = ["past_key_values"]
    attribute_map = {
        "num_hidden_layers": "n_layer",
        "num_attention_heads": "n_head",
        "num_kv_heads": "n_head_kv",
    }

    def __init__(
        self,
        vocab_size=250880,
        hidden_size=64,
        n_layer=2,
        n_head=8,
        layer_norm_epsilon=1e-5,
        initializer_range=0.02,
        use_cache=True,
        bos_token_id=1,
        eos_token_id=2,
        hidden_dropout=0.0,
        attention_dropout=0.0,
        multi_query=True,
        n_head_kv=None,
        alibi=False,
        bias=False,
        parallel_attn=False,
        new_decoder_architecture=False,
        **kwargs,
    ) -> None:
        self.vocab_size = vocab_size
        # Backward compatibility with n_embed kwarg
        n_embed = kwargs.pop("n_embed", None)
        self.hidden_size = hidden_size if n_embed is None else n_embed
        self.n_layer = n_layer
        self.n_head = n_head
        self.layer_norm_epsilon = layer_norm_epsilon
        self.initializer_range = initializer_range
        self.use_cache = use_cache
        self.hidden_dropout = hidden_dropout
        self.attention_dropout = attention_dropout

        self.bos_token_id = bos_token_id
        self.eos_token_id = eos_token_id
        self.multi_query = multi_query
        self.n_head_kv = 1 if n_head_kv is None else n_head_kv
        self.alibi = alibi
        self.bias = bias
        self.parallel_attn = parallel_attn
        self.new_decoder_architecture = new_decoder_architecture

        if self.hidden_size == 8192:
            # Hack for falcon-40b
            self.new_decoder_architecture = True

        super().__init__(bos_token_id=bos_token_id,
                         eos_token_id=eos_token_id,
                         **kwargs)

    @property
    def head_dim(self):
        return self.hidden_size // self.n_head

    @property
    def rotary(self):
        return not self.alibi
View file @
1b0bd0fe
...
@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
...
@@ -9,7 +9,7 @@ from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
SchedulerConfig
)
SchedulerConfig
)
from
vllm.model_executor
import
get_model
,
InputMetadata
,
set_random_seed
from
vllm.model_executor
import
get_model
,
InputMetadata
,
set_random_seed
from
vllm.model_executor.parallel_utils.parallel_state
import
(
from
vllm.model_executor.parallel_utils.parallel_state
import
(
initialize_model_parallel
,
initialize_all_reduce_launcher
)
initialize_model_parallel
)
from
vllm.sampling_params
import
SamplingParams
from
vllm.sampling_params
import
SamplingParams
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
,
SequenceOutputs
from
vllm.sequence
import
SequenceData
,
SequenceGroupMetadata
,
SequenceOutputs
from
vllm.worker.cache_engine
import
CacheEngine
from
vllm.worker.cache_engine
import
CacheEngine
...
@@ -65,11 +65,6 @@ class Worker:
...
@@ -65,11 +65,6 @@ class Worker:
# Initialize the model.
# Initialize the model.
set_random_seed
(
self
.
model_config
.
seed
)
set_random_seed
(
self
.
model_config
.
seed
)
self
.
model
=
get_model
(
self
.
model_config
)
self
.
model
=
get_model
(
self
.
model_config
)
initialize_all_reduce_launcher
(
self
.
scheduler_config
.
max_num_batched_tokens
,
self
.
model_config
.
get_hidden_size
(),
self
.
model_config
.
dtype
,
)
@
torch
.
inference_mode
()
@
torch
.
inference_mode
()
def
profile_num_available_blocks
(
def
profile_num_available_blocks
(
...
...