sglang / Commits / 0de7c2d0

Commit 0de7c2d0 (unverified)
Add e5-mistral modules [unreachable code] - step 1/3 (#983)
Authored Aug 08, 2024 by Ying Sheng; committed by GitHub on Aug 08, 2024
Parent: 6ed4e3b8
Showing 2 changed files with 136 additions and 0 deletions (+136 -0):

  python/sglang/srt/layers/pooler.py             +50 -0
  python/sglang/srt/models/llama_embedding.py    +86 -0
python/sglang/srt/layers/pooler.py (new file, mode 100644)
# adapted from
# https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
from dataclasses import dataclass
from enum import IntEnum

import torch
import torch.nn as nn

from sglang.srt.model_executor.model_runner import InputMetadata


class PoolingType(IntEnum):
    LAST = 0


@dataclass
class EmbeddingPoolerOutput:
    embeddings: torch.Tensor


class Pooler(nn.Module):
    """A layer that pools specific information from hidden states.

    This layer does the following:
    1. Extracts specific tokens or aggregates data based on the pooling method.
    2. Normalizes the output if specified.
    3. Returns structured results as `EmbeddingPoolerOutput`.

    Attributes:
        pooling_type: The type of pooling to use (currently only LAST is implemented).
        normalize: Whether to normalize the pooled data.
    """

    def __init__(self, pooling_type: PoolingType, normalize: bool):
        super().__init__()
        self.pooling_type = pooling_type
        self.normalize = normalize

    def forward(
        self, hidden_states: torch.Tensor, input_metadata: InputMetadata
    ) -> EmbeddingPoolerOutput:
        if self.pooling_type == PoolingType.LAST:
            # Sequences are packed along dim 0; the running sum of lengths
            # minus one gives the flat index of each sequence's last token.
            last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
            pooled_data = hidden_states[last_token_indices]
        else:
            raise ValueError(f"Invalid pooling type: {self.pooling_type}")

        if self.normalize:
            pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)

        return EmbeddingPoolerOutput(embeddings=pooled_data)
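For intuition, here is a minimal sketch (not part of the commit) of what LAST pooling does, assuming the batch's sequences are packed token-by-token along dim 0, which is what the cumsum over extend_seq_lens implies: the cumulative lengths minus one are the flat indices of each sequence's final token. The tensor shapes and values below are invented purely for illustration.

import torch

# Three sequences of lengths 2, 1 and 3, packed into 6 token rows (hidden size 2).
hidden_states = torch.arange(12, dtype=torch.float32).reshape(6, 2)
extend_seq_lens = torch.tensor([2, 1, 3])

# Running sum of lengths minus one -> flat index of each sequence's last token.
last_token_indices = torch.cumsum(extend_seq_lens, dim=0) - 1   # tensor([1, 2, 5])
pooled = hidden_states[last_token_indices]                      # shape (3, 2), one row per sequence

# With normalize=True the pooler L2-normalizes each embedding.
pooled = torch.nn.functional.normalize(pooled, p=2, dim=1)
print(pooled.shape)  # torch.Size([3, 2])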
python/sglang/srt/models/llama_embedding.py (new file, mode 100644)
from typing import Iterable, Optional, Tuple

import torch
from torch import nn
from transformers import LlamaConfig
from vllm.model_executor.model_loader.weight_utils import default_weight_loader

from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
from sglang.srt.model_executor.model_runner import InputMetadata
from sglang.srt.models.llama2 import LlamaForCausalLM, LlamaModel


class LlamaEmbeddingModel(nn.Module):
    def __init__(
        self,
        config: LlamaConfig,
        quant_config=None,
        cache_config=None,
        efficient_weight_load=False,
    ) -> None:
        super().__init__()
        self.model = LlamaModel(config, quant_config=quant_config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)

    @torch.no_grad()
    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        input_metadata: InputMetadata,
        input_embeds: torch.Tensor = None,
    ) -> EmbeddingPoolerOutput:
        hidden_states = self.model(input_ids, positions, input_metadata, input_embeds)
        return self.pooler(hidden_states, input_metadata)

    def load_weights(
        self,
        weights: Iterable[Tuple[str, torch.Tensor]],
        name=None,
        loaded_weight=None,
    ):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        params_dict = dict(self.model.named_parameters())

        def load_weights_per_param(name, loaded_weight):
            if "rotary_emb.inv_freq" in name or "projector" in name:
                return
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                return
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                if name.startswith("model.vision_tower") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    return
                if name.startswith("model.vision_tower") and name not in params_dict:
                    return
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)

        if name is None or loaded_weight is None:
            for name, loaded_weight in weights:
                load_weights_per_param(name, loaded_weight)
        else:
            load_weights_per_param(name, loaded_weight)


EntryClass = LlamaEmbeddingModel
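As a quick illustration (not part of the commit) of how load_weights_per_param resolves checkpoint names: Llama checkpoints store q_proj/k_proj/v_proj and gate_proj/up_proj as separate tensors, while the runtime parameters here are the fused qkv_proj and gate_up_proj, so stacked_params_mapping rewrites the name and passes a shard id to the fused parameter's weight_loader. The checkpoint key below is a typical Llama name used only as an example.

stacked_params_mapping = [
    # (fused param_name, checkpoint shard_name, shard_id)
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
    ("gate_up_proj", "gate_proj", 0),
    ("gate_up_proj", "up_proj", 1),
]

ckpt_name = "model.layers.0.self_attn.k_proj.weight"  # example checkpoint key
for param_name, weight_name, shard_id in stacked_params_mapping:
    if weight_name in ckpt_name:
        fused_name = ckpt_name.replace(weight_name, param_name)
        print(fused_name, shard_id)
        # -> model.layers.0.self_attn.qkv_proj.weight k
        break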