Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
e73f76ee
Unverified
Commit
e73f76ee
authored
Aug 17, 2024
by
Besher Alkurdi
Committed by
GitHub
Aug 17, 2024
Browse files
[Model] Pipeline parallel support for JAIS (#7603)
parent
d95cc0a5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
51 additions
and
18 deletions
+51
-18
vllm/config.py
vllm/config.py
+1
-0
vllm/model_executor/models/jais.py
vllm/model_executor/models/jais.py
+50
-18
No files found.
vllm/config.py
View file @
e73f76ee
...
...
@@ -36,6 +36,7 @@ _PP_SUPPORTED_MODELS = [
"AquilaForCausalLM"
,
"DeepseekV2ForCausalLM"
,
"InternLMForCausalLM"
,
"JAISLMHeadModel"
,
"LlamaForCausalLM"
,
"LLaMAForCausalLM"
,
"MistralForCausalLM"
,
...
...
vllm/model_executor/models/jais.py
View file @
e73f76ee
...
...
@@ -20,14 +20,14 @@
"""Inference-only Jais model compatible with HuggingFace weights."""
import
math
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
from
typing
import
Iterable
,
List
,
Optional
,
Tuple
,
Union
import
torch
from
torch
import
nn
from
vllm.attention
import
Attention
,
AttentionMetadata
from
vllm.config
import
CacheConfig
from
vllm.distributed
import
(
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
get_pp_group
,
get_tensor_model_parallel_rank
,
get_tensor_model_parallel_world_size
)
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
QKVParallelLinear
,
...
...
@@ -43,6 +43,8 @@ from vllm.model_executor.sampling_metadata import SamplingMetadata
from
vllm.sequence
import
IntermediateTensors
,
SamplerOutput
from
vllm.transformers_utils.configs
import
JAISConfig
from
.utils
import
is_pp_missing_parameter
,
make_layers
class
SwiGLUActivation
(
nn
.
Module
):
...
...
@@ -216,6 +218,7 @@ class JAISModel(nn.Module):
config
:
JAISConfig
,
cache_config
:
Optional
[
CacheConfig
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
prefix
:
str
=
""
,
):
super
().
__init__
()
self
.
config
=
config
...
...
@@ -231,10 +234,15 @@ class JAISModel(nn.Module):
self
.
embeddings_scale
=
config
.
embeddings_scale
else
:
self
.
embeddings_scale
=
config
.
mup_embeddings_scale
self
.
h
=
nn
.
ModuleList
([
JAISBlock
(
config
,
cache_config
,
quant_config
)
for
_
in
range
(
config
.
num_hidden_layers
)
])
self
.
start_layer
,
self
.
end_layer
,
self
.
h
=
make_layers
(
config
.
num_hidden_layers
,
lambda
prefix
:
JAISBlock
(
config
=
config
,
cache_config
=
cache_config
,
quant_config
=
quant_config
),
prefix
=
f
"
{
prefix
}
.h"
,
)
self
.
ln_f
=
nn
.
LayerNorm
(
self
.
embed_dim
,
eps
=
config
.
layer_norm_epsilon
)
def
forward
(
...
...
@@ -243,19 +251,29 @@ class JAISModel(nn.Module):
position_ids
:
torch
.
Tensor
,
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
)
->
torch
.
Tensor
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
self
.
wpe
is
not
None
:
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]:
if
get_pp_group
().
is_first_rank
:
inputs_embeds
=
self
.
wte
(
input_ids
)
if
self
.
wpe
is
not
None
:
position_embeds
=
self
.
wpe
(
position_ids
)
hidden_states
=
inputs_embeds
+
position_embeds
else
:
hidden_states
=
inputs_embeds
hidden_states
*=
torch
.
tensor
(
float
(
self
.
embeddings_scale
),
dtype
=
hidden_states
.
dtype
)
else
:
hidden_states
=
inputs_embeds
hidden_states
*=
torch
.
tensor
(
float
(
self
.
embeddings_scale
),
dtype
=
hidden_states
.
dtype
)
assert
intermediate_tensors
is
not
None
hidden_states
=
intermediate_tensors
[
"hidden_states"
]
for
i
in
range
(
len
(
self
.
h
)
):
for
i
in
range
(
self
.
start_layer
,
self
.
end_layer
):
layer
=
self
.
h
[
i
]
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
],
attn_metadata
)
hidden_states
=
layer
(
hidden_states
,
kv_caches
[
i
-
self
.
start_layer
],
attn_metadata
)
if
not
get_pp_group
().
is_last_rank
:
return
IntermediateTensors
({
"hidden_states"
:
hidden_states
})
hidden_states
=
self
.
ln_f
(
hidden_states
)
return
hidden_states
...
...
@@ -290,9 +308,9 @@ class JAISLMHeadModel(nn.Module):
kv_caches
:
List
[
torch
.
Tensor
],
attn_metadata
:
AttentionMetadata
,
intermediate_tensors
:
Optional
[
IntermediateTensors
]
=
None
,
)
->
torch
.
Tensor
:
)
->
Union
[
IntermediateTensors
,
torch
.
Tensor
]
:
hidden_states
=
self
.
transformer
(
input_ids
,
positions
,
kv_caches
,
attn_metadata
)
attn_metadata
,
intermediate_tensors
)
return
hidden_states
def
compute_logits
(
...
...
@@ -304,6 +322,16 @@ class JAISLMHeadModel(nn.Module):
sampling_metadata
)
return
logits
def
make_empty_intermediate_tensors
(
self
,
batch_size
:
int
,
dtype
:
torch
.
dtype
,
device
:
torch
.
device
)
->
IntermediateTensors
:
return
IntermediateTensors
({
"hidden_states"
:
torch
.
zeros
((
batch_size
,
self
.
config
.
hidden_size
),
dtype
=
dtype
,
device
=
device
),
})
def
sample
(
self
,
logits
:
torch
.
Tensor
,
...
...
@@ -327,6 +355,10 @@ class JAISLMHeadModel(nn.Module):
continue
if
not
name
.
startswith
(
"transformer."
):
name
=
"transformer."
+
name
if
is_pp_missing_parameter
(
name
,
self
):
continue
param
=
params_dict
[
name
]
# The HF's GPT-2 implementation uses Conv1D instead of Linear.
# Because of this, we need to transpose the weights.
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment