Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
1591c68f
Commit
1591c68f
authored
May 25, 2024
by
zhuwenwen
Browse files
merge v0.4.2
parents
09bcf00b
c7f2cf2b
Changes
265
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1049 additions
and
330 deletions
+1049
-330
vllm/lora/fully_sharded_layers.py
vllm/lora/fully_sharded_layers.py
+262
-0
vllm/lora/layers.py
vllm/lora/layers.py
+181
-127
vllm/lora/lora.py
vllm/lora/lora.py
+17
-11
vllm/lora/models.py
vllm/lora/models.py
+40
-36
vllm/lora/punica.py
vllm/lora/punica.py
+43
-0
vllm/lora/utils.py
vllm/lora/utils.py
+59
-1
vllm/lora/worker_manager.py
vllm/lora/worker_manager.py
+12
-9
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
...l_executor/guided_decoding/lm_format_enforcer_decoding.py
+1
-0
vllm/model_executor/guided_decoding/outlines_decoding.py
vllm/model_executor/guided_decoding/outlines_decoding.py
+5
-3
vllm/model_executor/guided_decoding/outlines_logits_processors.py
...el_executor/guided_decoding/outlines_logits_processors.py
+3
-4
vllm/model_executor/layers/activation.py
vllm/model_executor/layers/activation.py
+3
-0
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
...=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
+140
-0
vllm/model_executor/layers/fused_moe/fused_moe.py
vllm/model_executor/layers/fused_moe/fused_moe.py
+21
-13
vllm/model_executor/layers/layernorm.py
vllm/model_executor/layers/layernorm.py
+5
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+175
-89
vllm/model_executor/layers/logits_processor.py
vllm/model_executor/layers/logits_processor.py
+18
-15
vllm/model_executor/layers/quantization/__init__.py
vllm/model_executor/layers/quantization/__init__.py
+7
-4
vllm/model_executor/layers/quantization/aqlm.py
vllm/model_executor/layers/quantization/aqlm.py
+9
-6
vllm/model_executor/layers/quantization/awq.py
vllm/model_executor/layers/quantization/awq.py
+11
-8
vllm/model_executor/layers/quantization/base_config.py
vllm/model_executor/layers/quantization/base_config.py
+37
-4
No files found.
vllm/lora/fully_sharded_layers.py
0 → 100644
View file @
1591c68f
# pylint: disable=unused-argument
from
typing
import
TYPE_CHECKING
,
List
,
Optional
import
torch
import
torch.nn
as
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.distributed.communication_op
import
(
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
from
vllm.distributed.parallel_state
import
get_tensor_model_parallel_rank
from
vllm.lora.layers
import
(
ColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
)
from
vllm.lora.punica
import
bgmv
,
dispatch_bgmv_low_level
if
TYPE_CHECKING
:
pass
def
_fully_sharded_can_replace
(
can_replace
):
"""
decorator which adds the condition of fully sharded loras
intended to wrap can_replace_layer()
"""
def
dec
(
*
args
,
**
kwargs
):
return
(
can_replace
(
*
args
,
**
kwargs
)
and
kwargs
[
'lora_config'
].
fully_sharded_loras
)
return
dec
# these layers are based on the tensor parallelism strategy given in
# Y. Sheng et al., S-LoRA: Serving Thousands of Concurrent LoRA Adapters. 2023,
# https://arxiv.org/abs/2311.03285.
class ColumnParallelLinearWithShardedLoRA(ColumnParallelLinearWithLoRA):
    """
    Differs from ColumnParallelLinearWithLoRA by slicing LoRA A also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(self, lora_a: torch.Tensor) -> torch.Tensor:
        """Return this tensor-parallel rank's slice of ``lora_a``.

        The slice is taken along dim 1 of ``lora_a``; its width is dim 2 of
        ``self.lora_a_stacked`` (the per-partition LoRA A output size).
        """
        tp_rank = get_tensor_model_parallel_rank()
        shard_size = self.lora_a_stacked.shape[2]
        start_idx = tp_rank * shard_size
        return lora_a[:, start_idx:start_idx + shard_size]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add the sharded LoRA contribution.

        This method is named ``apply`` (and calls ``quant_method.apply`` on
        the base layer) to match ColumnParallelLinearWithLoRA, whose forward
        pass dispatches to ``self.apply``. Under the old name
        ``apply_weights`` this override was never reached.

        Args:
            x: input activations; flattened to 2D on the last dim internally.
            bias: optional bias forwarded to the base layer.

        Returns:
            The base layer output with the LoRA update added, in the base
            output's original shape.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x, bias)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        # float32 scratch buffer holding this rank's slice of the
        # intermediate (rows of x by per-partition LoRA A output size).
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)

        # presumably accumulates the LoRA A projection of x into buffer;
        # verify against vllm.lora.punica.bgmv
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # all-gather the per-rank buffers before applying LoRA B
        buffer = tensor_model_parallel_all_gather(buffer)
        bgmv(output, buffer, self.lora_b_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        # now have column partitioned output

        output = output.view(*out_orig_shape)
        return output

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x, bias)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Return whether this class can replace ``source_layer``.

        Delegates to the parent check with ``decorate=False`` and, through
        the decorator, additionally requires
        ``lora_config.fully_sharded_loras``.
        """
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


def _mcp_apply_weights(x, bias, layer):
    """
    MergedColumnParallelLinearWithShardedLoRA and
    QKVParallelLinearWithShardedLora share the same
    LoRa weight application method.

    The main difference is the step by shard_size for lora_b which can
    vary for QKVParallelLinearWithShardedLora but is constant for
    MergedColumnParallelLinearWithShardedLoRA.

    Args:
        x: input activations; flattened to 2D on the last dim internally.
        bias: optional bias forwarded to the base layer.
        layer: the LoRA layer; its ``base_layer``, ``lora_a_stacked``,
            ``lora_b_stacked``, ``indices`` and ``indices_len`` are read.

    Returns:
        The base layer output with every slice's LoRA update added, in the
        base output's original shape.
    """
    # expecting 2 for column parallel and 3 for qkv
    n = len(layer.lora_a_stacked)
    # The base layer is driven through quant_method.apply, matching the
    # non-sharded LoRA layers (previously linear_method.apply_weights).
    output = layer.base_layer.quant_method.apply(layer.base_layer, x, bias)

    x = x.view(-1, x.shape[-1])
    output, out_orig_shape = output.view(-1, output.shape[-1]), output.shape
    # One float32 scratch buffer per packed slice; all slices share the
    # per-partition LoRA A output size of slice 0.
    buffers = torch.zeros((n, x.shape[0], layer.lora_a_stacked[0].shape[2]),
                          dtype=torch.float32,
                          device=x.device)
    for idx in range(n):
        # presumably accumulates the LoRA A projection of x into
        # buffers[idx]; verify against vllm.lora.punica.bgmv
        bgmv(buffers[idx], x, layer.lora_a_stacked[idx],
             layer.indices[:layer.indices_len[0]], 0, 1.0)

    # all-gather the per-rank buffers before applying LoRA B
    buffers = tensor_model_parallel_all_gather(buffers)
    left_offset = 0
    for idx in range(n):
        shard_size = layer.lora_b_stacked[idx].shape[2]
        # Each slice is written at its own column offset of the packed
        # output; the offset advances by that slice's shard size.
        dispatch_bgmv_low_level(output, buffers[idx],
                                layer.lora_b_stacked[idx],
                                layer.indices[:layer.indices_len[0]], 0, 1.0,
                                left_offset, shard_size)
        left_offset += shard_size

    output = output.view(*out_orig_shape)
    # now have column partitioned and packed output
    return output


class MergedColumnParallelLinearWithShardedLoRA(
        MergedColumnParallelLinearWithLoRA):
    """
    Differs from MergedColumnParallelLinearWithLoRA by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
            self, lora_a: List[Optional[torch.Tensor]]
    ) -> List[Optional[torch.Tensor]]:
        """Return this rank's slice of each of the two packed LoRA A's.

        Entries that are ``None`` are passed through unchanged, matching
        MergedQKVParallelLinearWithShardedLora.slice_lora_a; previously a
        ``None`` entry raised ``TypeError`` when sliced.
        """
        output_shard_size = self.lora_a_stacked[0].shape[2]
        output_start_idx = self.tp_rank * output_shard_size
        output_end_idx = output_start_idx + output_shard_size
        return [
            sub_lora_a[:, output_start_idx:output_end_idx]
            if sub_lora_a is not None else None
            for sub_lora_a in (lora_a[0], lora_a[1])
        ]

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add the sharded, packed LoRA update.

        Named ``apply`` to match the hook name used by the column-parallel
        LoRA parent classes; under the old name ``apply_weights`` this
        override was never reached.
        """
        return _mcp_apply_weights(x, bias, self)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x, bias)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Return whether this class can replace ``source_layer``.

        Delegates to the parent check with ``decorate=False`` and, through
        the decorator, additionally requires
        ``lora_config.fully_sharded_loras``.
        """
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class MergedQKVParallelLinearWithShardedLora(MergedQKVParallelLinearWithLora):
    """
    Differs from QKVParallelLinearWithLora by slicing the
    LoRA A's also.

    Based on S-LoRA, slicing happens along the rank dim.
    """

    def slice_lora_a(
            self, lora_a: List[Optional[torch.Tensor]]
    ) -> List[Optional[torch.Tensor]]:
        """Return this rank's slice of each of the q, k and v LoRA A's.

        Each entry is sliced along dim 1 using the dim-2 size of the
        matching ``self.lora_a_stacked`` entry; ``None`` entries are passed
        through unchanged.
        """
        shard_size = [self.lora_a_stacked[i].shape[2] for i in range(3)]
        start_idx = [self.tp_rank * shard_size[i] for i in range(3)]
        lora_a = [
            lora_a[i][:, start_idx[i]:start_idx[i] + shard_size[i]]
            if lora_a[i] is not None else None for i in range(3)
        ]
        return lora_a

    def apply(self, x: torch.Tensor,
              bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Run the base layer, then add the sharded, packed LoRA update.

        Named ``apply`` to match the hook name used by the column-parallel
        LoRA parent classes; under the old name ``apply_weights`` this
        override was never reached.
        """
        return _mcp_apply_weights(x, bias, self)

    def apply_weights(self, x: torch.Tensor,
                      bias: Optional[torch.Tensor]) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x, bias)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Return whether this class can replace ``source_layer``.

        Delegates to the parent check with ``decorate=False`` and, through
        the decorator, additionally requires
        ``lora_config.fully_sharded_loras``.
        """
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )


class RowParallelLinearWithShardedLoRA(RowParallelLinearWithLoRA):
    """
    Differs from RowParallelLinearWithLoRA by slicing the
    LoRA B's also.

    Based on S-LoRA, slicing happens along the output dim.
    This yields a combined partial sum from the row parallel base
    layer and column partitioned output from the LoRA.
    """

    def slice_lora_b(self, lora_b: torch.Tensor) -> torch.Tensor:
        """Return this tensor-parallel rank's slice of ``lora_b``.

        The slice is taken along dim 1 of ``lora_b``; its width is dim 2 of
        ``self.lora_b_stacked``.
        """
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        end_idx = (self.tp_rank + 1) * shard_size
        return lora_b[:, start_idx:end_idx]

    def apply(self, x: torch.Tensor) -> torch.Tensor:
        """Run the base layer, then add this rank's slice of the LoRA update.

        Named ``apply`` and calling ``quant_method.apply`` on the base layer
        to match the hook used by the column-parallel LoRA layers.
        NOTE(review): assumes RowParallelLinearWithLoRA.forward dispatches
        to ``self.apply`` in the same way - confirm against the parent.

        Args:
            x: input activations; flattened to 2D on the last dim internally.

        Returns:
            The base layer output with the LoRA update added into this
            rank's column slice, in the base output's original shape. The
            result still has to be all-reduced by the caller.
        """
        output = self.base_layer.quant_method.apply(self.base_layer, x)

        x = x.view(-1, x.shape[-1])
        output, out_orig_shape = output.view(-1,
                                             output.shape[-1]), output.shape
        # float32 scratch buffer for the LoRA A intermediate
        buffer = torch.zeros((x.shape[0], self.lora_a_stacked.shape[2]),
                             dtype=torch.float32,
                             device=x.device)
        # presumably accumulates the LoRA A projection of x into buffer;
        # verify against vllm.lora.punica.bgmv
        bgmv(buffer, x, self.lora_a_stacked,
             self.indices[:self.indices_len[0]], 0, 1.0)
        buffer = tensor_model_parallel_all_reduce(buffer)

        # following S-LoRA, allows the fusing of all_gather and all_reduce
        # by adding the column partitioned lora output to a slice of output
        # tensor, which is a partial sum due to row parallel. All that
        # remains is a standard all_reduce. User should be aware though that
        # the output is not the same as a normal row_parallel, it should be
        # reduced before being used
        shard_size = self.lora_b_stacked.shape[2]
        start_idx = self.tp_rank * shard_size
        dispatch_bgmv_low_level(output, buffer, self.lora_b_stacked,
                                self.indices[:self.indices_len[0]], 0, 1.0,
                                start_idx, shard_size)

        output = output.view(*out_orig_shape)
        return output

    def apply_weights(self, x: torch.Tensor) -> torch.Tensor:
        """Backward-compatible alias for :meth:`apply`."""
        return self.apply(x)

    @classmethod
    @_fully_sharded_can_replace
    def can_replace_layer(cls, source_layer: nn.Module,
                          lora_config: LoRAConfig, packed_modules_list: List,
                          model_config: Optional[PretrainedConfig]) -> bool:
        """Return whether this class can replace ``source_layer``.

        Delegates to the parent check with ``decorate=False`` and, through
        the decorator, additionally requires
        ``lora_config.fully_sharded_loras``.
        """
        # specifying kwargs so they can be easily accessed in decorator
        return super().can_replace_layer(
            source_layer=source_layer,
            lora_config=lora_config,
            packed_modules_list=packed_modules_list,
            model_config=model_config,
            decorate=False,
        )
vllm/lora/layers.py
View file @
1591c68f
# pylint: disable=unused-argument
# pylint: disable=unused-argument
import
inspect
import
math
import
math
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Set
,
Tuple
,
Typ
e
from
typing
import
TYPE_CHECKING
,
List
,
Optional
,
Tupl
e
import
torch
import
torch
import
torch.nn
as
nn
import
torch.nn
as
nn
...
@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
...
@@ -16,6 +15,7 @@ from vllm.distributed import (get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_all_reduce
,
tensor_model_parallel_gather
)
tensor_model_parallel_gather
)
from
vllm.distributed.utils
import
divide
from
vllm.lora.punica
import
add_lora
,
add_lora_slice
,
bgmv
from
vllm.lora.punica
import
add_lora
,
add_lora_slice
,
bgmv
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
from
vllm.model_executor.layers.linear
import
(
ColumnParallelLinear
,
MergedColumnParallelLinear
,
MergedColumnParallelLinear
,
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
...
@@ -23,7 +23,7 @@ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear
)
RowParallelLinear
)
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
from
vllm.model_executor.layers.vocab_parallel_embedding
import
(
ParallelLMHead
,
VocabParallelEmbedding
)
VocabParallelEmbedding
)
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
pass
pass
...
@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
...
@@ -45,6 +45,21 @@ def _get_lora_device(base_layer: nn.Module) -> torch.device:
raise
ValueError
(
f
"Unsupported base layer:
{
base_layer
}
"
)
raise
ValueError
(
f
"Unsupported base layer:
{
base_layer
}
"
)
def
_not_fully_sharded_can_replace
(
can_replace
):
"""
decorator which adds the condition of not using fully sharded loras
intended to wrap can_replace_layer()
"""
def
dec
(
*
args
,
**
kwargs
):
decorate
=
kwargs
.
pop
(
'decorate'
)
if
'decorate'
in
kwargs
else
True
condition
=
(
not
kwargs
[
'lora_config'
].
fully_sharded_loras
if
decorate
else
True
)
return
can_replace
(
*
args
,
**
kwargs
)
and
condition
return
dec
def
_apply_lora
(
def
_apply_lora
(
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
lora_a_stacked
:
torch
.
Tensor
,
...
@@ -130,6 +145,14 @@ class LoRAMapping:
...
@@ -130,6 +145,14 @@ class LoRAMapping:
class
BaseLayerWithLoRA
(
nn
.
Module
):
class
BaseLayerWithLoRA
(
nn
.
Module
):
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Slice lora a if splitting for tensor parallelism."""
...
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
"""Slice lora b if splitting with tensor parallelism."""
...
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
max_loras
:
int
,
max_loras
:
int
,
...
@@ -176,6 +199,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -176,6 +199,8 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
def
__init__
(
self
,
base_layer
:
VocabParallelEmbedding
)
->
None
:
def
__init__
(
self
,
base_layer
:
VocabParallelEmbedding
)
->
None
:
super
().
__init__
()
super
().
__init__
()
self
.
base_layer
=
base_layer
self
.
base_layer
=
base_layer
self
.
embeddings_slice
:
Optional
[
Tuple
[
int
,
int
]]
self
.
embeddings_weights
:
Optional
[
torch
.
Tensor
]
def
create_lora_weights
(
def
create_lora_weights
(
self
,
self
,
...
@@ -233,9 +258,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -233,9 +258,10 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
lora_a_stacked
.
shape
[
0
]
*
self
.
lora_a_stacked
.
shape
[
1
],
self
.
lora_a_stacked
.
shape
[
0
]
*
self
.
lora_a_stacked
.
shape
[
1
],
self
.
lora_a_stacked
.
shape
[
2
],
self
.
lora_a_stacked
.
shape
[
2
],
)
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
# Lazily initialized.
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
indices
:
torch
.
Tensor
self
.
embeddings_indices
=
None
self
.
indices_len
:
List
[
int
]
self
.
embeddings_indices
:
torch
.
Tensor
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_a_stacked
[
index
]
=
0
...
@@ -267,6 +293,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -267,6 +293,7 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
self
.
embeddings_tensors
.
shape
[
1
],
self
.
embeddings_tensors
.
shape
[
1
],
self
.
embeddings_tensors
.
shape
[
2
]
self
.
embeddings_tensors
.
shape
[
2
]
)[
self
.
embeddings_slice
[
0
]:
self
.
embeddings_slice
[
1
]]
)[
self
.
embeddings_slice
[
0
]:
self
.
embeddings_slice
[
1
]]
assert
self
.
embeddings_weights
is
not
None
self
.
embeddings_weights
[:
embeddings
.
shape
[
0
]].
copy_
(
embeddings
)
self
.
embeddings_weights
[:
embeddings
.
shape
[
0
]].
copy_
(
embeddings
)
def
set_mapping
(
def
set_mapping
(
...
@@ -313,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
...
@@ -313,6 +340,11 @@ class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
class
ColumnParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
class
ColumnParallelLinearWithLoRA
(
BaseLayerWithLoRA
):
"""
LoRA on top of ColumnParallelLinear layer.
LoRA B is sliced for tensor parallelism.
"""
def
__init__
(
self
,
base_layer
:
ColumnParallelLinear
)
->
None
:
def
__init__
(
self
,
base_layer
:
ColumnParallelLinear
)
->
None
:
super
().
__init__
()
super
().
__init__
()
...
@@ -327,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -327,10 +359,15 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras
:
int
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
self
.
lora_a_stacked
=
torch
.
zeros
(
self
.
lora_a_stacked
=
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -343,15 +380,27 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -343,15 +380,27 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
output_dim
=
self
.
lora_b_stacked
.
shape
[
2
]
self
.
output_dim
=
self
.
lora_b_stacked
.
shape
[
2
]
# lazily initialized.
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
output_dim
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
return
lora_b
def
set_lora
(
def
set_lora
(
self
,
self
,
index
:
int
,
index
:
int
,
...
@@ -360,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -360,12 +409,11 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
lora_a
=
self
.
slice_lora_a
(
lora_a
)
shard_size
=
self
.
output_dim
lora_b
=
self
.
slice_lora_b
(
lora_b
)
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[:,
start_idx
:
end_idx
]
self
.
lora_a_stacked
[
index
,
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
lora_a
.
T
,
non_blocking
=
True
)
lora_a
.
T
,
non_blocking
=
True
)
...
@@ -384,10 +432,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -384,10 +432,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
indices
=
base_indices
self
.
indices
=
base_indices
self
.
indices_len
=
indices_len
self
.
indices_len
=
indices_len
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
_apply_lora
(
_apply_lora
(
x
,
x
,
self
.
lora_a_stacked
,
self
.
lora_a_stacked
,
...
@@ -411,7 +458,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -411,7 +458,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
if
not
self
.
base_layer
.
skip_bias_add
else
None
)
if
not
self
.
base_layer
.
skip_bias_add
else
None
)
# Matrix multiply.
# Matrix multiply.
output_parallel
=
self
.
apply
_weights
(
input_
,
bias
)
output_parallel
=
self
.
apply
(
input_
,
bias
)
if
self
.
base_layer
.
gather_output
:
if
self
.
base_layer
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
@@ -422,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -422,6 +469,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
return
output
,
output_bias
return
output
,
output_bias
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
@@ -447,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -447,6 +495,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
max_loras
:
int
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
n_slices
=
2
n_slices
=
2
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
n_slices
if
not
(
len
(
self
.
base_layer
.
output_sizes
)
==
n_slices
and
self
.
base_layer
.
output_sizes
[
0
]
and
self
.
base_layer
.
output_sizes
[
0
]
...
@@ -455,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -455,12 +504,17 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"LoRAColumnParallelLinear2Slice requires 2 slices with "
"the same size."
)
"the same size."
)
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
self
.
lora_a_stacked
=
tuple
(
self
.
lora_a_stacked
=
tuple
(
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -475,8 +529,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -475,8 +529,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
device
=
self
.
device
,
device
=
self
.
device
,
)
for
_
in
range
(
n_slices
))
)
for
_
in
range
(
n_slices
))
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
self
.
output_dim
=
self
.
lora_b_stacked
[
0
].
shape
[
2
]
# Lazily initialized.
self
.
indices
:
torch
.
Tensor
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
0
][
index
]
=
0
...
@@ -484,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -484,6 +539,18 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
0
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
self
.
lora_b_stacked
[
1
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
shard_size
=
self
.
output_dim
start_idx
=
self
.
tp_rank
*
shard_size
end_idx
=
(
self
.
tp_rank
+
1
)
*
shard_size
lora_b
=
[
lora_b
[
0
][:,
start_idx
:
end_idx
],
lora_b
[
1
][:,
start_idx
:
end_idx
]
]
return
lora_b
def
set_lora
(
def
set_lora
(
self
,
self
,
index
:
int
,
index
:
int
,
...
@@ -494,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -494,13 +561,8 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
lora_a
=
self
.
slice_lora_a
(
lora_a
)
shard_size
=
self
.
output_dim
lora_b
=
self
.
slice_lora_b
(
lora_b
)
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_b
=
lora_b
[
0
][:,
start_idx
:
end_idx
],
lora_b
[
1
][:,
start_idx
:
end_idx
]
if
lora_a
[
0
]
is
not
None
:
if
lora_a
[
0
]
is
not
None
:
self
.
lora_a_stacked
[
0
][
self
.
lora_a_stacked
[
0
][
...
@@ -517,10 +579,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -517,10 +579,9 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
lora_b
[
1
].
T
,
non_blocking
=
True
)
lora_b
[
1
].
T
,
non_blocking
=
True
)
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
_apply_lora_packed_nslice
(
_apply_lora_packed_nslice
(
x
,
x
,
self
.
lora_a_stacked
,
self
.
lora_a_stacked
,
...
@@ -532,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
...
@@ -532,6 +593,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
return
output
return
output
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
@@ -623,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -623,21 +685,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
max_loras
:
int
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
q_proj_shard_size
=
(
self
.
base_layer
.
num_heads
*
self
.
base_layer
.
head_size
)
self
.
base_layer
.
head_size
)
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
self
.
kv_proj_shard_size
=
(
self
.
base_layer
.
num_kv_heads
*
self
.
base_layer
.
head_size
)
self
.
base_layer
.
head_size
)
self
.
q_shard_id
=
tp_rank
self
.
q_shard_id
=
self
.
tp_rank
self
.
kv_shard_id
=
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
self
.
kv_shard_id
=
self
.
tp_rank
//
self
.
base_layer
.
num_kv_head_replicas
lora_a_output_size_per_partition
=
(
lora_config
.
max_lora_rank
if
not
lora_config
.
fully_sharded_loras
else
divide
(
lora_config
.
max_lora_rank
,
self
.
tp_size
))
# q, k, v
# q, k, v
self
.
lora_a_stacked
=
(
self
.
lora_a_stacked
=
(
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -645,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -645,7 +711,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -653,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -653,7 +719,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
torch
.
zeros
(
torch
.
zeros
(
max_loras
,
max_loras
,
1
,
1
,
lora_
config
.
max_lora_rank
,
lora_
a_output_size_per_partition
,
self
.
input_size
,
self
.
input_size
,
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
...
@@ -690,7 +756,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -690,7 +756,8 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
kv_proj_shard_size
)
self
.
kv_proj_shard_size
)
self
.
packed_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
packed_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
standard_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
standard_indices
:
Optional
[
torch
.
Tensor
]
=
None
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
# lazily initialized.
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
0
][
index
]
=
0
self
.
lora_a_stacked
[
0
][
index
]
=
0
...
@@ -700,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -700,6 +767,25 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_a_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
self
.
lora_b_stacked
[
2
][
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
List
[
torch
.
Tensor
])
->
List
[
torch
.
Tensor
]:
if
lora_b
[
0
]
is
not
None
:
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
)]
if
lora_b
[
1
]
is
not
None
:
lora_b_k
=
lora_b
[
1
][:,
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
if
lora_b
[
2
]
is
not
None
:
lora_b_v
=
lora_b
[
2
][:,
self
.
kv_proj_shard_size
*
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
(
self
.
kv_shard_id
+
1
)]
lora_b
=
[
lora_b_q
,
lora_b_k
,
lora_b_v
]
return
lora_b
def
set_lora
(
def
set_lora
(
self
,
self
,
index
:
int
,
index
:
int
,
...
@@ -710,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -710,40 +796,24 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
if
self
.
tp_size
>
1
:
if
self
.
tp_size
>
1
:
if
lora_b
[
0
]
is
not
None
:
lora_a
=
self
.
slice_lora_a
(
lora_a
)
lora_b_q
=
lora_b
[
0
][:,
self
.
q_proj_shard_size
*
lora_b
=
self
.
slice_lora_b
(
lora_b
)
self
.
q_shard_id
:
self
.
q_proj_shard_size
*
(
self
.
q_shard_id
+
1
)]
if
lora_b
[
0
]
is
not
None
:
self
.
lora_b_stacked
[
0
][
lora_b_q
=
lora_b
[
0
]
index
,
0
,
:
lora_b_q
.
shape
[
1
],
:
lora_b_q
.
shape
[
0
]].
copy_
(
self
.
lora_b_stacked
[
0
][
lora_b_q
.
T
,
non_blocking
=
True
)
index
,
0
,
:
lora_b_q
.
shape
[
1
],
:
lora_b_q
.
shape
[
0
]].
copy_
(
if
lora_b
[
1
]
is
not
None
:
lora_b_q
.
T
,
non_blocking
=
True
)
lora_b_k
=
lora_b
[
1
][:,
self
.
kv_proj_shard_size
*
if
lora_b
[
1
]
is
not
None
:
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
lora_b_k
=
lora_b
[
1
]
(
self
.
kv_shard_id
+
1
)]
self
.
lora_b_stacked
[
1
][
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b_k
.
shape
[
1
],
:
lora_b_k
.
shape
[
0
]].
copy_
(
index
,
0
,
:
lora_b_k
.
shape
[
1
],
:
lora_b_k
.
shape
[
0
]].
copy_
(
lora_b_k
.
T
,
non_blocking
=
True
)
lora_b_k
.
T
,
non_blocking
=
True
)
if
lora_b
[
2
]
is
not
None
:
if
lora_b
[
2
]
is
not
None
:
lora_b_v
=
lora_b
[
2
]
lora_b_v
=
lora_b
[
2
][:,
self
.
kv_proj_shard_size
*
self
.
lora_b_stacked
[
2
][
self
.
kv_shard_id
:
self
.
kv_proj_shard_size
*
index
,
0
,
:
lora_b_v
.
shape
[
1
],
:
lora_b_v
.
shape
[
0
]].
copy_
(
(
self
.
kv_shard_id
+
1
)]
lora_b_v
.
T
,
non_blocking
=
True
)
self
.
lora_b_stacked
[
2
][
index
,
0
,
:
lora_b_v
.
shape
[
1
],
:
lora_b_v
.
shape
[
0
]].
copy_
(
lora_b_v
.
T
,
non_blocking
=
True
)
else
:
if
lora_b
[
0
]
is
not
None
:
self
.
lora_b_stacked
[
0
][
index
,
0
,
:
lora_b
[
0
].
shape
[
1
],
:
lora_b
[
0
].
shape
[
0
]].
copy_
(
lora_b
[
0
].
T
,
non_blocking
=
True
)
if
lora_b
[
1
]
is
not
None
:
self
.
lora_b_stacked
[
1
][
index
,
0
,
:
lora_b
[
1
].
shape
[
1
],
:
lora_b
[
1
].
shape
[
0
]].
copy_
(
lora_b
[
1
].
T
,
non_blocking
=
True
)
if
lora_b
[
2
]
is
not
None
:
self
.
lora_b_stacked
[
2
][
index
,
0
,
:
lora_b
[
2
].
shape
[
1
],
:
lora_b
[
2
].
shape
[
0
]].
copy_
(
lora_b
[
2
].
T
,
non_blocking
=
True
)
if
lora_a
[
0
]
is
not
None
:
if
lora_a
[
0
]
is
not
None
:
self
.
lora_a_stacked
[
0
][
self
.
lora_a_stacked
[
0
][
...
@@ -758,10 +828,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -758,10 +828,9 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
index
,
0
,
:
lora_a
[
2
].
shape
[
1
],
:
lora_a
[
2
].
shape
[
0
]].
copy_
(
index
,
0
,
:
lora_a
[
2
].
shape
[
1
],
:
lora_a
[
2
].
shape
[
0
]].
copy_
(
lora_a
[
2
].
T
,
non_blocking
=
True
)
lora_a
[
2
].
T
,
non_blocking
=
True
)
def
apply_weights
(
self
,
x
:
torch
.
Tensor
,
def
apply
(
self
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
])
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
,
bias
)
self
.
base_layer
,
x
,
bias
)
_apply_lora_packed_nslice
(
_apply_lora_packed_nslice
(
x
,
x
,
self
.
lora_a_stacked
,
self
.
lora_a_stacked
,
...
@@ -773,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
...
@@ -773,6 +842,7 @@ class MergedQKVParallelLinearWithLora(ColumnParallelLinearWithLoRA):
return
output
return
output
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
@@ -794,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -794,6 +864,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
max_loras
:
int
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
None
:
self
.
lora_config
=
lora_config
self
.
tp_rank
=
get_tensor_model_parallel_rank
()
self
.
lora_a_stacked
=
torch
.
zeros
(
self
.
lora_a_stacked
=
torch
.
zeros
(
(
(
max_loras
,
max_loras
,
...
@@ -804,23 +876,40 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -804,23 +876,40 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
tp_size
=
get_tensor_model_parallel_world_size
()
lora_b_output_size_per_partition
=
(
self
.
output_size
if
not
lora_config
.
fully_sharded_loras
else
divide
(
self
.
output_size
,
tp_size
))
self
.
lora_b_stacked
=
torch
.
zeros
(
self
.
lora_b_stacked
=
torch
.
zeros
(
(
(
max_loras
,
max_loras
,
1
,
1
,
self
.
output_size
,
lora_b_
output_size
_per_partition
,
lora_config
.
max_lora_rank
,
lora_config
.
max_lora_rank
,
),
),
dtype
=
lora_config
.
lora_dtype
,
dtype
=
lora_config
.
lora_dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
indices
:
Optional
[
torch
.
Tensor
]
=
None
# Lazily initialized
self
.
indices_len
:
Optional
[
List
[
int
]]
=
None
self
.
indices
:
torch
.
Tensor
self
.
indices_len
:
List
[
int
]
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
self
.
lora_b_stacked
[
index
]
=
0
def
slice_lora_a
(
self
,
lora_a
:
torch
.
Tensor
)
->
torch
.
Tensor
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
shard_size
=
self
.
input_size
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:]
return
lora_a
def
slice_lora_b
(
self
,
lora_b
:
torch
.
Tensor
)
->
torch
.
Tensor
:
return
lora_b
def
set_lora
(
def
set_lora
(
self
,
self
,
index
:
int
,
index
:
int
,
...
@@ -829,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -829,12 +918,10 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
embeddings_tensor
:
Optional
[
torch
.
Tensor
],
):
):
self
.
reset_lora
(
index
)
self
.
reset_lora
(
index
)
if
self
.
base_layer
.
tp_size
>
1
:
if
self
.
base_layer
.
tp_size
>
1
:
tensor_model_parallel_rank
=
get_tensor_model_parallel_rank
()
lora_a
=
self
.
slice_lora_a
(
lora_a
)
shard_size
=
self
.
input_size
lora_b
=
self
.
slice_lora_b
(
lora_b
)
start_idx
=
tensor_model_parallel_rank
*
shard_size
end_idx
=
(
tensor_model_parallel_rank
+
1
)
*
shard_size
lora_a
=
lora_a
[
start_idx
:
end_idx
,
:]
self
.
lora_a_stacked
[
index
,
self
.
lora_a_stacked
[
index
,
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
0
,
:
lora_a
.
shape
[
1
],
:
lora_a
.
shape
[
0
]].
copy_
(
...
@@ -854,9 +941,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -854,9 +941,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
indices
=
base_indices
self
.
indices
=
base_indices
self
.
indices_len
=
indices_len
self
.
indices_len
=
indices_len
def
apply_weights
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
apply
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
output
=
self
.
base_layer
.
linear_method
.
apply_weights
(
output
=
self
.
base_layer
.
quant_method
.
apply
(
self
.
base_layer
,
x
)
self
.
base_layer
,
x
)
_apply_lora
(
_apply_lora
(
x
,
x
,
self
.
lora_a_stacked
,
self
.
lora_a_stacked
,
...
@@ -889,7 +975,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -889,7 +975,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
# Matrix multiply.
output_parallel
=
self
.
apply
_weights
(
input_parallel
)
output_parallel
=
self
.
apply
(
input_parallel
)
if
self
.
base_layer
.
reduce_results
and
self
.
base_layer
.
tp_size
>
1
:
if
self
.
base_layer
.
reduce_results
and
self
.
base_layer
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
...
@@ -911,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
...
@@ -911,6 +997,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
self
.
base_layer
,
"weight"
)
else
self
.
base_layer
.
qweight
@
classmethod
@
classmethod
@
_not_fully_sharded_can_replace
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
def
can_replace_layer
(
cls
,
source_layer
:
nn
.
Module
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
...
@@ -991,9 +1078,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -991,9 +1078,10 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
dtype
=
self
.
dtype
,
dtype
=
self
.
dtype
,
device
=
self
.
device
,
device
=
self
.
device
,
)
)
self
.
indices
=
None
# Lazily initialized.
self
.
indices_padded
=
None
self
.
indices
:
torch
.
Tensor
self
.
indices_len
=
None
self
.
indices_len
:
List
[
int
]
self
.
indices_padded
:
torch
.
Tensor
def
reset_lora
(
self
,
index
:
int
):
def
reset_lora
(
self
,
index
:
int
):
self
.
lora_a_stacked
[
index
]
=
0
self
.
lora_a_stacked
[
index
]
=
0
...
@@ -1091,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
...
@@ -1091,37 +1179,3 @@ class LogitsProcessorWithLoRA(BaseLayerWithLoRA):
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
model_config
:
Optional
[
PretrainedConfig
])
->
bool
:
# Special handling for the LogitsProcessor.
# Special handling for the LogitsProcessor.
return
False
return
False
_all_lora_classes
:
Set
[
Type
[
BaseLayerWithLoRA
]]
=
{
cls
for
cls
in
globals
().
values
()
if
inspect
.
isclass
(
cls
)
and
issubclass
(
cls
,
BaseLayerWithLoRA
)
and
cls
is
not
BaseLayerWithLoRA
}
def
from_layer
(
layer
:
nn
.
Module
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
nn
.
Module
:
for
lora_cls
in
_all_lora_classes
:
if
lora_cls
.
can_replace_layer
(
layer
,
lora_config
,
packed_modules_list
,
model_config
):
ret
=
lora_cls
(
layer
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
return
layer
def
from_layer_logits_processor
(
layer
:
LogitsProcessor
,
lm_head
:
ParallelLMHead
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
LogitsProcessorWithLoRA
:
ret
=
LogitsProcessorWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
vllm/lora/lora.py
View file @
1591c68f
...
@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
...
@@ -97,9 +97,9 @@ class PackedLoRALayerWeights(LoRALayerWeights):
self
,
self
,
module_name
:
str
,
module_name
:
str
,
rank
:
int
,
rank
:
int
,
lora_alphas
:
List
[
int
],
lora_alphas
:
List
[
Optional
[
int
]
]
,
lora_a
:
List
[
torch
.
Tensor
],
lora_a
:
List
[
Optional
[
torch
.
Tensor
]
]
,
lora_b
:
List
[
torch
.
Tensor
],
lora_b
:
List
[
Optional
[
torch
.
Tensor
]
]
,
scaling
:
Optional
[
List
[
float
]]
=
None
,
scaling
:
Optional
[
List
[
float
]]
=
None
,
)
->
None
:
)
->
None
:
super
().
__init__
(
super
().
__init__
(
...
@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
...
@@ -108,17 +108,20 @@ class PackedLoRALayerWeights(LoRALayerWeights):
lora_alpha
=
0
,
lora_alpha
=
0
,
lora_a
=
lora_a
,
lora_a
=
lora_a
,
lora_b
=
lora_b
,
lora_b
=
lora_b
,
scaling
=
scaling
,
scaling
=
scaling
,
# type: ignore
embeddings_tensor
=
None
,
embeddings_tensor
=
None
,
)
)
self
.
lora_alphas
=
lora_alphas
self
.
lora_alphas
=
lora_alphas
if
scaling
is
None
:
if
scaling
is
None
:
self
.
scaling
=
[
self
.
scaling
=
[
# type: ignore
lora_alpha
/
self
.
rank
for
lora_alpha
in
self
.
lora_alphas
lora_alpha
/
self
.
rank
# type: ignore # noqa
for
lora_alpha
in
self
.
lora_alphas
]
]
@
classmethod
@
classmethod
def
pack
(
cls
,
loras
:
List
[
"LoRALayerWeights"
])
->
"PackedLoRALayerWeights"
:
def
pack
(
cls
,
loras
:
List
[
Optional
[
"LoRALayerWeights"
]]
)
->
"PackedLoRALayerWeights"
:
"""Pack a list of LoRAs into a single LoRA.
"""Pack a list of LoRAs into a single LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
If LoRA is None, it signifies that the submodule does not have a LoRA.
...
@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
...
@@ -136,16 +139,19 @@ class PackedLoRALayerWeights(LoRALayerWeights):
[
lora
.
lora_alpha
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_alpha
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_a
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_a
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_b
if
lora
is
not
None
else
None
for
lora
in
loras
],
[
lora
.
lora_b
if
lora
is
not
None
else
None
for
lora
in
loras
],
scaling
=
[
1
if
lora
is
not
None
else
None
for
lora
in
loras
])
scaling
=
[
1
if
lora
is
not
None
else
None
# type: ignore
for
lora
in
loras
])
return
obj
return
obj
def
optimize
(
self
)
->
"PackedLoRALayerWeights"
:
def
optimize
(
self
)
->
"PackedLoRALayerWeights"
:
"""Optimize the LoRA by merging the scaling into lora_b."""
"""Optimize the LoRA by merging the scaling into lora_b."""
for
i
in
range
(
len
(
self
.
lora_b
)):
for
i
in
range
(
len
(
self
.
lora_b
)):
if
self
.
scaling
[
i
]
==
1
or
self
.
lora_b
[
i
]
is
None
:
if
self
.
scaling
[
i
]
==
1
or
self
.
lora_b
[
i
]
is
None
:
# type: ignore
continue
continue
self
.
lora_b
[
i
]
*=
self
.
scaling
[
i
]
self
.
lora_b
[
i
]
*=
self
.
scaling
[
i
]
# type: ignore
self
.
scaling
[
i
]
=
1
self
.
scaling
[
i
]
=
1
# type: ignore
return
self
return
self
@
property
@
property
...
...
vllm/lora/models.py
View file @
1591c68f
...
@@ -3,7 +3,7 @@ import json
...
@@ -3,7 +3,7 @@ import json
import
math
import
math
import
os
import
os
import
re
import
re
from
typing
import
Callable
,
Dict
,
Hashable
,
List
,
Optional
,
Tuple
,
Type
from
typing
import
Callable
,
Dict
,
List
,
Optional
,
Tuple
,
Type
import
safetensors.torch
import
safetensors.torch
import
torch
import
torch
...
@@ -11,10 +11,10 @@ from torch import nn
...
@@ -11,10 +11,10 @@ from torch import nn
from
vllm.config
import
LoRAConfig
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
LoRAMapping
,
from_layer
,
from
vllm.lora.layers
import
BaseLayerWithLoRA
,
LoRAMapping
from_layer_logits_processor
)
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.lora
import
LoRALayerWeights
,
PackedLoRALayerWeights
from
vllm.lora.utils
import
parse_fine_tuned_lora_name
,
replace_submodule
from
vllm.lora.utils
import
(
from_layer
,
from_layer_logits_processor
,
parse_fine_tuned_lora_name
,
replace_submodule
)
from
vllm.utils
import
LRUCache
,
is_pin_memory_available
from
vllm.utils
import
LRUCache
,
is_pin_memory_available
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -53,44 +53,46 @@ def convert_mapping(
...
@@ -53,44 +53,46 @@ def convert_mapping(
embeddings.
embeddings.
indices_len: List of lengths of the above tensors.
indices_len: List of lengths of the above tensors.
"""
"""
ind
ices
=
list
(
mapping
.
index_mapping
).
copy
()
ind
ex_mapping_indices
:
List
[
int
]
=
list
(
mapping
.
index_mapping
).
copy
()
embedding_indices
=
indices
.
copy
()
embedding_indices
=
index_mapping_
indices
.
copy
()
lora_indices
=
indices
.
copy
()
lora_indices
=
index_mapping_
indices
.
copy
()
prompt_mapping
=
[
prompt_mapping
:
List
[
int
]
=
[
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
lora_index_to_id
.
index
(
x
)
if
x
>
0
else
-
1
for
x
in
mapping
.
prompt_mapping
for
x
in
mapping
.
prompt_mapping
]
]
lora_idx
=
None
lora_idx
=
None
for
i
in
range
(
len
(
indices
)):
for
i
in
range
(
len
(
index_mapping_
indices
)):
# TODO index can be slow. optimize
# TODO index can be slow. optimize
lora_idx
=
(
lora_index_to_id
.
index
(
indices
[
i
])
lora_idx
=
(
lora_index_to_id
.
index
(
index_mapping_
indices
[
i
])
if
indices
[
i
]
>
0
else
-
1
)
if
index_mapping_
indices
[
i
]
>
0
else
-
1
)
embedding_indices
[
i
]
=
lora_idx
if
indices
[
i
]
>
0
else
0
embedding_indices
[
i
]
=
lora_idx
if
index_mapping_
indices
[
i
]
>
0
else
0
indices
[
i
]
=
i
index_mapping_
indices
[
i
]
=
i
lora_indices
[
i
]
=
lora_idx
lora_indices
[
i
]
=
lora_idx
indices
=
torch
.
tensor
([
indices
,
lora_indices
,
embedding_indices
],
indices
=
torch
.
tensor
(
dtype
=
torch
.
long
,
[
index_mapping_indices
,
lora_indices
,
embedding_indices
],
device
=
"cuda"
)
dtype
=
torch
.
long
,
prompt_mapping
=
torch
.
tensor
(
prompt_mapping
,
device
=
"cuda"
)
device
=
"cuda"
,
prompt_mapping_tensor
=
torch
.
tensor
(
prompt_mapping
,
dtype
=
torch
.
long
)
device
=
"cuda"
,
dtype
=
torch
.
long
)
embeddings_indices
=
torch
.
stack
([
embeddings_indices
=
torch
.
stack
([
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
extra_vocab_size
,
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
)
indices
[
2
]
*
(
vocab_size
+
extra_vocab_size
)
])
])
embeddings_indices
[
embeddings_indices
==
-
1
]
=
max_loras
-
1
embeddings_indices
[
embeddings_indices
==
-
1
]
=
max_loras
-
1
base_indices
=
indices
[
1
]
base_indices
=
indices
[
1
]
sampler_indices
=
prompt_mapping
sampler_indices
=
prompt_mapping
_tensor
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
=
sampler_indices
.
clone
()
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
[
sampler_indices_padded
==
-
1
]
=
max_loras
-
1
sampler_indices_padded
=
(
sampler_indices_padded
=
(
torch
.
arange
(
torch
.
arange
(
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
0
,
len
(
sampler_indices_padded
),
device
=
"cuda"
,
dtype
=
torch
.
long
)
+
(
sampler_indices_padded
*
len
(
sampler_indices_padded
)))
(
sampler_indices_padded
*
len
(
sampler_indices_padded
)))
indices_len
=
(
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
indices_len
=
[
sampler_indices_padded
.
shape
[
-
1
],
base_indices
.
shape
[
-
1
],
sampler_indices
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
])
sampler_indices_padded
.
shape
[
-
1
],
embeddings_indices
.
shape
[
-
1
]
]
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
return
(
base_indices
,
sampler_indices
,
sampler_indices_padded
,
embeddings_indices
,
indices_len
)
embeddings_indices
,
indices_len
)
...
@@ -149,6 +151,7 @@ class LoRAModel:
...
@@ -149,6 +151,7 @@ class LoRAModel:
if
module_name
not
in
loras
:
if
module_name
not
in
loras
:
lora_embeddings_tensor
=
None
lora_embeddings_tensor
=
None
if
embeddings
:
if
embeddings
:
assert
embedding_modules
is
not
None
embeddings_module
=
next
(
embeddings_module
=
next
(
(
k
for
k
in
embedding_modules
if
k
in
module_name
),
(
k
for
k
in
embedding_modules
if
k
in
module_name
),
None
)
None
)
...
@@ -171,6 +174,7 @@ class LoRAModel:
...
@@ -171,6 +174,7 @@ class LoRAModel:
else
:
else
:
loras
[
module_name
].
lora_b
=
tensor
.
to
(
device
=
device
,
loras
[
module_name
].
lora_b
=
tensor
.
to
(
device
=
device
,
dtype
=
dtype
).
t
()
dtype
=
dtype
).
t
()
assert
embedding_padding_modules
is
not
None
if
any
(
name
in
module_name
if
any
(
name
in
module_name
for
name
in
embedding_padding_modules
for
name
in
embedding_padding_modules
)
and
target_embedding_padding
is
not
None
:
)
and
target_embedding_padding
is
not
None
:
...
@@ -295,11 +299,10 @@ class LoRAModelManager:
...
@@ -295,11 +299,10 @@ class LoRAModelManager:
self
.
max_num_batched_tokens
,
self
.
max_num_batched_tokens
,
dtype
=
torch
.
long
,
dtype
=
torch
.
long
,
device
=
"cuda"
)
device
=
"cuda"
)
self
.
offsets
=
[]
# 4 is the number of indicies tensors defined above
# 4 is the number of indicies tensors defined above
# base_indices, sampler_indices, sampler_indices_padded,
# base_indices, sampler_indices, sampler_indices_padded,
# embeddings_indices
# embeddings_indices
self
.
indices_len
=
[
None
]
*
4
self
.
indices_len
:
List
[
Optional
[
int
]]
=
[
None
]
*
4
self
.
model
:
nn
.
Module
=
model
self
.
model
:
nn
.
Module
=
model
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
if
hasattr
(
self
.
model
,
"supported_lora_modules"
):
...
@@ -312,7 +315,7 @@ class LoRAModelManager:
...
@@ -312,7 +315,7 @@ class LoRAModelManager:
self
.
_registered_loras
:
Dict
[
int
,
LoRAModel
]
=
{}
self
.
_registered_loras
:
Dict
[
int
,
LoRAModel
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
# Dict instead of a Set for compatibility with LRUCache.
self
.
_active_loras
:
Dict
[
int
,
None
]
=
{}
self
.
_active_loras
:
Dict
[
int
,
None
]
=
{}
self
.
_last_mapping
=
None
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
_create_lora_modules
()
self
.
model
.
lora_manager
=
self
self
.
model
.
lora_manager
=
self
...
@@ -342,8 +345,8 @@ class LoRAModelManager:
...
@@ -342,8 +345,8 @@ class LoRAModelManager:
index
,
_
=
first_free_slot
index
,
_
=
first_free_slot
self
.
_active_loras
[
lora_id
]
=
None
self
.
_active_loras
[
lora_id
]
=
None
lora_model
=
self
.
_registered_loras
[
lora_id
]
lora_model
=
self
.
_registered_loras
[
lora_id
]
logger
.
debug
(
logger
.
debug
(
"Activating LoRA. int id: %d, slot index: %d"
,
f
"Activating LoRA. int id:
{
lora_model
.
id
}
,
slot index:
{
index
}
"
)
lora_model
.
id
,
index
)
self
.
lora_index_to_id
[
index
]
=
lora_model
.
id
self
.
lora_index_to_id
[
index
]
=
lora_model
.
id
for
module_name
,
module
in
self
.
modules
.
items
():
for
module_name
,
module
in
self
.
modules
.
items
():
module_lora
=
lora_model
.
get_lora
(
module_name
)
module_lora
=
lora_model
.
get_lora
(
module_name
)
...
@@ -370,7 +373,7 @@ class LoRAModelManager:
...
@@ -370,7 +373,7 @@ class LoRAModelManager:
return
True
return
True
return
False
return
False
def
_add_lora
(
self
,
lora
:
LoRAModel
)
->
bool
:
def
_add_lora
(
self
,
lora
:
LoRAModel
):
self
.
_create_merged_loras_inplace
(
lora
)
self
.
_create_merged_loras_inplace
(
lora
)
self
.
_registered_loras
[
lora
.
id
]
=
lora
self
.
_registered_loras
[
lora
.
id
]
=
lora
...
@@ -418,7 +421,7 @@ class LoRAModelManager:
...
@@ -418,7 +421,7 @@ class LoRAModelManager:
def
get_lora
(
self
,
lora_id
:
int
)
->
Optional
[
LoRAModel
]:
def
get_lora
(
self
,
lora_id
:
int
)
->
Optional
[
LoRAModel
]:
return
self
.
_registered_loras
.
get
(
lora_id
,
None
)
return
self
.
_registered_loras
.
get
(
lora_id
,
None
)
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
"""Remove all LoRAModels from the manager."""
"""Remove all LoRAModels from the manager."""
self
.
_registered_loras
.
clear
()
self
.
_registered_loras
.
clear
()
self
.
lora_index_to_id
=
[
None
]
*
self
.
lora_slots
self
.
lora_index_to_id
=
[
None
]
*
self
.
lora_slots
...
@@ -467,6 +470,7 @@ class LoRAModelManager:
...
@@ -467,6 +470,7 @@ class LoRAModelManager:
continue
continue
parts
=
module_name
.
split
(
"."
)
parts
=
module_name
.
split
(
"."
)
if
module_name
not
in
self
.
packed_modules
:
if
module_name
not
in
self
.
packed_modules
:
assert
embedding_modules
is
not
None
if
parts
[
-
1
]
in
embedding_modules
:
if
parts
[
-
1
]
in
embedding_modules
:
input_dim
=
(
module
.
base_layer
.
org_vocab_size
+
input_dim
=
(
module
.
base_layer
.
org_vocab_size
+
self
.
lora_config
.
lora_extra_vocab_size
if
self
.
lora_config
.
lora_extra_vocab_size
if
...
@@ -500,7 +504,7 @@ class LoRAModelManager:
...
@@ -500,7 +504,7 @@ class LoRAModelManager:
else
:
else
:
parts
=
module_name
.
split
(
"."
)
parts
=
module_name
.
split
(
"."
)
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
subloras
=
[]
subloras
:
List
[
Optional
[
"LoRALayerWeights"
]]
=
[]
for
i
,
r
in
enumerate
(
replacements
):
for
i
,
r
in
enumerate
(
replacements
):
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
+
"."
+
r
,
module_name
+
"."
+
r
,
...
@@ -538,7 +542,7 @@ class LoRAModelManager:
...
@@ -538,7 +542,7 @@ class LoRAModelManager:
def
_create_merged_loras_inplace
(
self
,
lora_model
:
LoRAModel
)
->
None
:
def
_create_merged_loras_inplace
(
self
,
lora_model
:
LoRAModel
)
->
None
:
for
module_name
,
new_module_names
in
self
.
packed_modules
.
items
():
for
module_name
,
new_module_names
in
self
.
packed_modules
.
items
():
replacement_loras
=
[]
replacement_loras
:
List
[
Optional
[
LoRALayerWeights
]]
=
[]
has_replacement
=
False
has_replacement
=
False
for
r
in
new_module_names
:
for
r
in
new_module_names
:
lora
=
lora_model
.
get_lora
(
r
)
lora
=
lora_model
.
get_lora
(
r
)
...
@@ -557,13 +561,13 @@ class LoRAModelManager:
...
@@ -557,13 +561,13 @@ class LoRAModelManager:
class
LoRALRUCache
(
LRUCache
[
LoRAModel
]):
class
LoRALRUCache
(
LRUCache
[
LoRAModel
]):
def
__init__
(
self
,
capacity
:
int
,
deactivate_lora_fn
:
Callable
[[
Hashable
],
def
__init__
(
self
,
capacity
:
int
,
deactivate_lora_fn
:
Callable
[[
int
],
None
]):
bool
]):
super
().
__init__
(
capacity
)
super
().
__init__
(
capacity
)
self
.
deactivate_lora_fn
=
deactivate_lora_fn
self
.
deactivate_lora_fn
=
deactivate_lora_fn
def
_on_remove
(
self
,
key
:
Hashable
,
value
:
LoRAModel
):
def
_on_remove
(
self
,
key
:
int
,
value
:
LoRAModel
):
logger
.
debug
(
f
"Removing LoRA. int id:
{
key
}
"
)
logger
.
debug
(
"Removing LoRA. int id:
%d"
,
key
)
self
.
deactivate_lora_fn
(
key
)
self
.
deactivate_lora_fn
(
key
)
return
super
().
_on_remove
(
key
,
value
)
return
super
().
_on_remove
(
key
,
value
)
...
...
vllm/lora/punica.py
View file @
1591c68f
...
@@ -49,6 +49,49 @@ def bgmv(
...
@@ -49,6 +49,49 @@ def bgmv(
punica_kernels
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
punica_kernels
.
dispatch_bgmv
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
)
def
dispatch_bgmv_low_level
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
w_t_all
:
torch
.
Tensor
,
indicies
:
torch
.
LongTensor
,
layer_idx
:
int
,
scale
:
float
,
y_offset
:
int
,
y_slice_size
:
int
):
"""
Same as `bgmv` but you can operate on slices of y.
Pass whole y, define y_offset and y_slice_size.
Semantics:
y[i] += (
x[i].unsqueeze(0)
@ w_t_all[indices[i], layer_idx, :, :].transpose(-1, -2)
* scale
).squeeze(0)
Args:
y: Shape: `[B, H2]`. Output vectors. Will be changed in-place.
x: Shape: `[B, H1]`. Input vectors.
w_t_all: Shape: `[None, L, y_slice_size, H1]`. Column partition of
all of the transposed LoRA matrices.
indicies: Shape: `[B]`. Indices of the LoRA weights.
layer_idx: Layer index of LoRA weights.
scale: Scaling factor.
y_offset: Offset to apply to the starting column of y.
y_slice_size: Size of the y column slice.
"""
try
:
import
vllm._punica_C
as
punica_kernels
except
ImportError
as
e
:
_raise_import_error
(
e
)
punica_kernels
.
dispatch_bgmv_low_level
(
y
,
x
,
w_t_all
,
indicies
,
layer_idx
,
scale
,
x
.
size
(
1
),
y_slice_size
,
y_offset
,
)
def
add_lora
(
y
:
torch
.
Tensor
,
def
add_lora
(
y
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
wa_t_all
:
torch
.
Tensor
,
...
...
vllm/lora/utils.py
View file @
1591c68f
from
typing
import
Tupl
e
from
typing
import
List
,
Optional
,
Set
,
Tuple
,
Typ
e
from
torch
import
nn
from
torch
import
nn
from
transformers
import
PretrainedConfig
from
vllm.config
import
LoRAConfig
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.lora.fully_sharded_layers
import
(
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLora
,
RowParallelLinearWithShardedLoRA
)
# being imported for _all_lora_classes below
# yapf conflicts with isort for this block
# yapf: disable
from
vllm.lora.layers
import
(
BaseLayerWithLoRA
,
ColumnParallelLinearWithLoRA
,
LogitsProcessorWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
MergedQKVParallelLinearWithLora
,
QKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
,
VocabParallelEmbeddingWithLoRA
)
# yapf: enable
from
vllm.model_executor.layers.logits_processor
import
LogitsProcessor
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
_all_lora_classes
:
Set
[
Type
[
BaseLayerWithLoRA
]]
=
{
VocabParallelEmbeddingWithLoRA
,
ColumnParallelLinearWithLoRA
,
MergedColumnParallelLinearWithLoRA
,
QKVParallelLinearWithLora
,
MergedQKVParallelLinearWithLora
,
RowParallelLinearWithLoRA
,
LogitsProcessorWithLoRA
,
ColumnParallelLinearWithShardedLoRA
,
MergedColumnParallelLinearWithShardedLoRA
,
MergedQKVParallelLinearWithShardedLora
,
RowParallelLinearWithShardedLoRA
}
def
from_layer
(
layer
:
nn
.
Module
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
packed_modules_list
:
List
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
)
->
nn
.
Module
:
for
lora_cls
in
_all_lora_classes
:
# specifying kwargs so they can be easily accessed in decorator
if
lora_cls
.
can_replace_layer
(
source_layer
=
layer
,
lora_config
=
lora_config
,
packed_modules_list
=
packed_modules_list
,
model_config
=
model_config
):
ret
=
lora_cls
(
layer
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
return
layer
def
from_layer_logits_processor
(
layer
:
LogitsProcessor
,
lm_head
:
ParallelLMHead
,
max_loras
:
int
,
lora_config
:
LoRAConfig
,
model_config
:
Optional
[
PretrainedConfig
]
=
None
,
)
->
LogitsProcessorWithLoRA
:
ret
=
LogitsProcessorWithLoRA
(
layer
,
lm_head
.
embedding_dim
,
lm_head
.
weight
.
dtype
,
lm_head
.
weight
.
device
)
ret
.
create_lora_weights
(
max_loras
,
lora_config
,
model_config
)
return
ret
def
replace_submodule
(
model
:
nn
.
Module
,
module_name
:
str
,
def
replace_submodule
(
model
:
nn
.
Module
,
module_name
:
str
,
new_module
:
nn
.
Module
)
->
nn
.
Module
:
new_module
:
nn
.
Module
)
->
nn
.
Module
:
...
...
vllm/lora/worker_manager.py
View file @
1591c68f
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
abc
import
ABC
,
abstractmethod
,
abstractproperty
from
typing
import
Any
,
Dict
,
List
,
Optional
,
Set
,
Type
from
typing
import
Any
,
Dict
,
List
,
Set
,
Type
import
torch
import
torch
...
@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@@ -37,7 +37,7 @@ class AbstractWorkerLoRAManager(ABC):
...
...
@
abstractmethod
@
abstractmethod
def
set_active_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
],
def
set_active_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
],
lora_mapping
:
LoRAMapping
)
->
None
:
lora_mapping
:
LoRAMapping
)
->
None
:
...
...
...
@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
...
@@ -54,7 +54,7 @@ class AbstractWorkerLoRAManager(ABC):
...
...
@
abstractmethod
@
abstractmethod
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
...
...
@
abstractmethod
@
abstractmethod
...
@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -81,10 +81,11 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
embedding_padding_modules
:
List
[
str
],
embedding_padding_modules
:
List
[
str
],
lora_model_cls
:
Type
[
LoRAModel
]
=
LoRAModel
,
lora_model_cls
:
Type
[
LoRAModel
]
=
LoRAModel
,
):
):
self
.
_lora_manager
:
Optional
[
LoRAModelManager
]
=
None
self
.
_lora_model_cls
=
lora_model_cls
self
.
_lora_model_cls
=
lora_model_cls
self
.
embedding_modules
=
embedding_modules
self
.
embedding_modules
=
embedding_modules
self
.
embedding_padding_modules
=
embedding_padding_modules
self
.
embedding_padding_modules
=
embedding_padding_modules
# Lazily initialized by create_lora_manager.
self
.
_lora_manager
:
LoRAModelManager
super
().
__init__
(
max_num_seqs
,
max_num_batched_tokens
,
vocab_size
,
super
().
__init__
(
max_num_seqs
,
max_num_batched_tokens
,
vocab_size
,
lora_config
,
device
)
lora_config
,
device
)
...
@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -104,7 +105,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
lora_manager_cls
=
self
.
_lora_manager_cls
,
lora_manager_cls
=
self
.
_lora_manager_cls
,
)
)
self
.
_lora_manager
:
LoRAModelManager
=
lora_manager
self
.
_lora_manager
=
lora_manager
return
lora_manager
.
model
return
lora_manager
.
model
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
def
set_active_loras
(
self
,
lora_requests
:
Set
[
LoRARequest
],
...
@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
...
@@ -188,7 +189,7 @@ class WorkerLoRAManager(AbstractWorkerLoRAManager):
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
def
remove_lora
(
self
,
lora_id
:
int
)
->
bool
:
return
self
.
_lora_manager
.
remove_lora
(
lora_id
)
return
self
.
_lora_manager
.
remove_lora
(
lora_id
)
def
remove_all_loras
(
self
)
->
bool
:
def
remove_all_loras
(
self
):
self
.
_lora_manager
.
remove_all_loras
()
self
.
_lora_manager
.
remove_all_loras
()
def
list_loras
(
self
)
->
Set
[
int
]:
def
list_loras
(
self
)
->
Set
[
int
]:
...
@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
...
@@ -217,10 +218,10 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
lora_config
=
self
.
lora_config
,
lora_config
=
self
.
lora_config
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
max_num_batched_tokens
=
self
.
max_num_batched_tokens
,
)
)
self
.
_lora_manager
:
LRUCacheLoRAModelManager
=
lora_manager
self
.
_lora_manager
=
lora_manager
return
lora_manager
.
model
return
lora_manager
.
model
def
_apply_loras
(
self
,
lora_requests
:
Lis
t
[
LoRARequest
])
->
None
:
def
_apply_loras
(
self
,
lora_requests
:
Se
t
[
LoRARequest
])
->
None
:
loras_map
=
{
loras_map
=
{
lora_request
.
lora_int_id
:
lora_request
lora_request
.
lora_int_id
:
lora_request
for
lora_request
in
lora_requests
if
lora_request
for
lora_request
in
lora_requests
if
lora_request
...
@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
...
@@ -237,12 +238,14 @@ class LRUCacheWorkerLoRAManager(WorkerLoRAManager):
if
lora_request
.
lora_int_id
not
in
self
.
list_loras
():
if
lora_request
.
lora_int_id
not
in
self
.
list_loras
():
# Remove before we load the new lora to save memory
# Remove before we load the new lora to save memory
if
len
(
self
.
_lora_manager
)
+
1
>
self
.
_lora_manager
.
capacity
:
if
len
(
self
.
_lora_manager
)
+
1
>
self
.
_lora_manager
.
capacity
:
assert
isinstance
(
self
.
_lora_manager
,
LRUCacheLoRAModelManager
)
self
.
_lora_manager
.
remove_oldest_lora
()
self
.
_lora_manager
.
remove_oldest_lora
()
lora
=
self
.
_load_lora
(
lora_request
)
lora
=
self
.
_load_lora
(
lora_request
)
loaded
=
self
.
_lora_manager
.
add_lora
(
lora
)
loaded
=
self
.
_lora_manager
.
add_lora
(
lora
)
else
:
else
:
# If the lora is already loaded, just touch it to
# If the lora is already loaded, just touch it to
# update its position in the caches
# update its position in the caches
loaded
=
self
.
_lora_manager
.
get_lora
(
lora_request
.
lora_int_id
)
loaded
=
self
.
_lora_manager
.
get_lora
(
lora_request
.
lora_int_id
)
is
not
None
self
.
_lora_manager
.
activate_lora
(
lora_request
.
lora_int_id
)
self
.
_lora_manager
.
activate_lora
(
lora_request
.
lora_int_id
)
return
loaded
return
loaded
vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py
View file @
1591c68f
...
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
...
@@ -61,6 +61,7 @@ def _normalize_json_schema_object(schema: Union[str, dict, BaseModel]) -> dict:
return
schema
return
schema
if
isinstance
(
schema
,
BaseModel
):
if
isinstance
(
schema
,
BaseModel
):
return
schema
.
model_json_schema
()
return
schema
.
model_json_schema
()
raise
AssertionError
(
f
"Unsupported schema type
{
schema
}
"
)
@
lru_cache
@
lru_cache
...
...
vllm/model_executor/guided_decoding/outlines_decoding.py
View file @
1591c68f
...
@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
...
@@ -74,7 +74,8 @@ async def get_outlines_guided_decoding_logits_processor(
result
=
await
loop
.
run_in_executor
(
global_thread_pool
,
result
=
await
loop
.
run_in_executor
(
global_thread_pool
,
_get_cached_logits_processor
,
guide
,
_get_cached_logits_processor
,
guide
,
tokenizer
,
mode
)
tokenizer
,
mode
,
request
.
guided_whitespace_pattern
)
logits_processor
=
copy
(
result
)
logits_processor
=
copy
(
result
)
# reset logits processor's internal state
# reset logits processor's internal state
...
@@ -117,9 +118,10 @@ def _get_guide_and_mode(
...
@@ -117,9 +118,10 @@ def _get_guide_and_mode(
@
lru_cache
(
maxsize
=
32
)
@
lru_cache
(
maxsize
=
32
)
def
_get_cached_logits_processor
(
guide
:
str
,
def
_get_cached_logits_processor
(
guide
:
str
,
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
mode
:
GuidedDecodingMode
):
mode
:
GuidedDecodingMode
,
whitespace_pattern
:
Union
[
str
,
None
]):
if
mode
==
GuidedDecodingMode
.
JSON
:
if
mode
==
GuidedDecodingMode
.
JSON
:
return
JSONLogitsProcessor
(
guide
,
tokenizer
)
return
JSONLogitsProcessor
(
guide
,
tokenizer
,
whitespace_pattern
)
elif
mode
==
GuidedDecodingMode
.
REGEX
or
mode
==
GuidedDecodingMode
.
CHOICE
:
elif
mode
==
GuidedDecodingMode
.
REGEX
or
mode
==
GuidedDecodingMode
.
CHOICE
:
return
RegexLogitsProcessor
(
guide
,
tokenizer
)
return
RegexLogitsProcessor
(
guide
,
tokenizer
)
elif
mode
==
GuidedDecodingMode
.
GRAMMAR
:
elif
mode
==
GuidedDecodingMode
.
GRAMMAR
:
...
...
vllm/model_executor/guided_decoding/outlines_logits_processors.py
View file @
1591c68f
...
@@ -18,7 +18,7 @@ import json
...
@@ -18,7 +18,7 @@ import json
import
math
import
math
from
collections
import
defaultdict
from
collections
import
defaultdict
from
functools
import
lru_cache
from
functools
import
lru_cache
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Optional
,
Union
from
typing
import
Callable
,
DefaultDict
,
Dict
,
List
,
Union
import
torch
import
torch
from
outlines.fsm.fsm
import
CFGFSM
,
FSM
,
RegexFSM
from
outlines.fsm.fsm
import
CFGFSM
,
FSM
,
RegexFSM
...
@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
...
@@ -80,10 +80,9 @@ class RegexLogitsProcessor(BaseLogitsProcessor):
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
class
JSONLogitsProcessor
(
RegexLogitsProcessor
):
def
__init__
(
self
,
def
__init__
(
self
,
schema
:
Union
[
str
,
Dict
,
BaseModel
],
schema
:
Union
[
str
,
Dict
,
BaseModel
],
tokenizer
:
PreTrainedTokenizerBase
,
tokenizer
:
PreTrainedTokenizerBase
,
whitespace_pattern
:
Opt
ion
al
[
str
]
=
None
):
whitespace_pattern
:
Un
ion
[
str
,
None
]
):
"""Compile the FSM that drives the JSON-guided generation.
"""Compile the FSM that drives the JSON-guided generation.
Parameters
Parameters
...
...
vllm/model_executor/layers/activation.py
View file @
1591c68f
...
@@ -67,6 +67,9 @@ class GeluAndMul(nn.Module):
...
@@ -67,6 +67,9 @@ class GeluAndMul(nn.Module):
ops
.
gelu_tanh_and_mul
(
out
,
x
)
ops
.
gelu_tanh_and_mul
(
out
,
x
)
return
out
return
out
def
extra_repr
(
self
)
->
str
:
return
f
'approximate=
{
repr
(
self
.
approximate
)
}
'
class
NewGELU
(
nn
.
Module
):
class
NewGELU
(
nn
.
Module
):
...
...
vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json
0 → 100644
View file @
1591c68f
{
"1"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"2"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"4"
:
{
"BLOCK_SIZE_M"
:
16
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
1
},
"8"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
5
},
"16"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"24"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
1
,
"num_warps"
:
8
,
"num_stages"
:
5
},
"32"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"48"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
3
},
"64"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
64
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
4
},
"96"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
2
},
"128"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
64
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
4
,
"num_stages"
:
3
},
"256"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
32
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
4
,
"num_stages"
:
5
},
"512"
:
{
"BLOCK_SIZE_M"
:
64
,
"BLOCK_SIZE_N"
:
128
,
"BLOCK_SIZE_K"
:
256
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
4
,
"num_stages"
:
2
},
"1024"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"1536"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"2048"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
64
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"3072"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
32
,
"num_warps"
:
8
,
"num_stages"
:
4
},
"4096"
:
{
"BLOCK_SIZE_M"
:
128
,
"BLOCK_SIZE_N"
:
256
,
"BLOCK_SIZE_K"
:
128
,
"GROUP_SIZE_M"
:
16
,
"num_warps"
:
8
,
"num_stages"
:
4
}
}
vllm/model_executor/layers/fused_moe/fused_moe.py
View file @
1591c68f
...
@@ -203,14 +203,15 @@ def moe_align_block_size(
...
@@ -203,14 +203,15 @@ def moe_align_block_size(
- The padding ensures that the total number of tokens is now divisible
- The padding ensures that the total number of tokens is now divisible
by block_size for proper block matrix operations.
by block_size for proper block matrix operations.
"""
"""
sorted_ids
=
torch
.
empty
(
max_num_tokens_padded
=
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
)
(
topk_ids
.
numel
()
+
num_experts
*
(
block_size
-
1
),
),
sorted_ids
=
torch
.
empty
((
max_num_tokens_padded
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
expert_ids
=
torch
.
empty
((
topk_ids
.
numel
()
+
num_experts
,
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
sorted_ids
.
fill_
(
topk_ids
.
numel
())
sorted_ids
.
fill_
(
topk_ids
.
numel
())
max_num_m_blocks
=
triton
.
cdiv
(
max_num_tokens_padded
,
block_size
)
expert_ids
=
torch
.
empty
((
max_num_m_blocks
,
),
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
num_tokens_post_pad
=
torch
.
empty
((
1
),
num_tokens_post_pad
=
torch
.
empty
((
1
),
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
device
=
topk_ids
.
device
)
device
=
topk_ids
.
device
)
...
@@ -220,8 +221,9 @@ def moe_align_block_size(
...
@@ -220,8 +221,9 @@ def moe_align_block_size(
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
def
invoke_fused_moe_kernel
(
A
:
torch
.
Tensor
,
B
:
torch
.
Tensor
,
C
:
torch
.
Tensor
,
B_scale
:
torch
.
Tensor
,
topk_weights
:
torch
.
Tensor
,
A_scale
:
Optional
[
torch
.
Tensor
],
topk_ids
:
torch
.
Tensor
,
B_scale
:
Optional
[
torch
.
Tensor
],
topk_weights
:
torch
.
Tensor
,
topk_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
sorted_token_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
expert_ids
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
num_tokens_post_padded
:
torch
.
Tensor
,
...
@@ -232,10 +234,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
...
@@ -232,10 +234,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
assert
sorted_token_ids
.
stride
(
0
)
==
1
assert
sorted_token_ids
.
stride
(
0
)
==
1
if
not
use_fp8
:
if
not
use_fp8
:
A_scale
=
None
assert
A_scale
is
None
assert
B_scale
is
None
assert
B_scale
is
None
else
:
else
:
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
)
A
,
A_scale
=
ops
.
scaled_fp8_quant
(
A
,
A_scale
)
assert
B_scale
is
not
None
assert
B_scale
is
not
None
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
grid
=
lambda
META
:
(
triton
.
cdiv
(
sorted_token_ids
.
shape
[
0
],
META
[
...
@@ -296,8 +298,8 @@ def get_moe_configs(E: int, N: int,
...
@@ -296,8 +298,8 @@ def get_moe_configs(E: int, N: int,
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
os
.
path
.
dirname
(
os
.
path
.
realpath
(
__file__
)),
"configs"
,
json_file_name
)
if
os
.
path
.
exists
(
config_file_path
):
if
os
.
path
.
exists
(
config_file_path
):
with
open
(
config_file_path
)
as
f
:
with
open
(
config_file_path
)
as
f
:
logger
.
info
(
logger
.
info
(
"Using configuration from %s for MoE layer."
,
f
"Using configuration from
{
config_file_path
}
for MoE layer."
)
config_file_path
)
# If a configuration has been found, return it
# If a configuration has been found, return it
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
return
{
int
(
key
):
val
for
key
,
val
in
json
.
load
(
f
).
items
()}
...
@@ -318,6 +320,8 @@ def fused_moe(
...
@@ -318,6 +320,8 @@ def fused_moe(
use_fp8
:
bool
=
False
,
use_fp8
:
bool
=
False
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
w2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a1_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
a2_scale
:
Optional
[
torch
.
Tensor
]
=
None
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
"""
"""
This function computes a Mixture of Experts (MoE) layer using two sets of
This function computes a Mixture of Experts (MoE) layer using two sets of
...
@@ -430,10 +434,13 @@ def fused_moe(
...
@@ -430,10 +434,13 @@ def fused_moe(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
sorted_token_ids
,
expert_ids
,
num_tokens_post_padded
=
moe_align_block_size
(
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
topk_ids
,
config
[
'BLOCK_SIZE_M'
],
E
)
compute_type
=
(
tl
.
bfloat16
if
hidden_states
.
dtype
==
torch
.
bfloat16
else
tl
.
float16
)
invoke_fused_moe_kernel
(
hidden_states
,
invoke_fused_moe_kernel
(
hidden_states
,
w1
,
w1
,
intermediate_cache1
,
intermediate_cache1
,
a1_scale
,
w1_scale
,
w1_scale
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
...
@@ -443,7 +450,7 @@ def fused_moe(
...
@@ -443,7 +450,7 @@ def fused_moe(
False
,
False
,
topk_ids
.
shape
[
1
],
topk_ids
.
shape
[
1
],
config
,
config
,
compute_type
=
tl
.
float16
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8
=
use_fp8
)
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
ops
.
silu_and_mul
(
intermediate_cache2
,
intermediate_cache1
.
view
(
-
1
,
N
))
...
@@ -451,6 +458,7 @@ def fused_moe(
...
@@ -451,6 +458,7 @@ def fused_moe(
invoke_fused_moe_kernel
(
intermediate_cache2
,
invoke_fused_moe_kernel
(
intermediate_cache2
,
w2
,
w2
,
intermediate_cache3
,
intermediate_cache3
,
a2_scale
,
w2_scale
,
w2_scale
,
topk_weights
,
topk_weights
,
topk_ids
,
topk_ids
,
...
@@ -460,7 +468,7 @@ def fused_moe(
...
@@ -460,7 +468,7 @@ def fused_moe(
True
,
True
,
1
,
1
,
config
,
config
,
compute_type
=
tl
.
float16
,
compute_type
=
compute_type
,
use_fp8
=
use_fp8
)
use_fp8
=
use_fp8
)
if
inplace
:
if
inplace
:
...
...
vllm/model_executor/layers/layernorm.py
View file @
1591c68f
...
@@ -64,3 +64,8 @@ class RMSNorm(nn.Module):
...
@@ -64,3 +64,8 @@ class RMSNorm(nn.Module):
self
.
variance_epsilon
,
self
.
variance_epsilon
,
)
)
return
out
return
out
def
extra_repr
(
self
)
->
str
:
s
=
f
"hidden_size=
{
self
.
weight
.
data
.
size
(
0
)
}
"
s
+=
f
", eps=
{
self
.
variance_epsilon
}
"
return
s
vllm/model_executor/layers/linear.py
View file @
1591c68f
from
abc
import
ABC
,
abstractmethod
from
abc
import
abstractmethod
from
typing
import
List
,
Optional
from
typing
import
List
,
Optional
import
torch
import
torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch
import
nn
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
from
vllm.distributed
import
(
divide
,
get_tensor_model_parallel_rank
,
...
@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
...
@@ -12,6 +11,8 @@ from vllm.distributed import (divide, get_tensor_model_parallel_rank,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_gather
,
tensor_model_parallel_all_reduce
)
tensor_model_parallel_all_reduce
)
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
,
QuantizeMethodBase
)
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.model_executor.utils
import
set_weight_attrs
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
...
@@ -28,7 +29,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
...
@@ -28,7 +29,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
class
LinearMethodBase
(
ABC
):
class
LinearMethodBase
(
QuantizeMethodBase
):
"""Base class for different (maybe quantized) linear methods."""
"""Base class for different (maybe quantized) linear methods."""
@
abstractmethod
@
abstractmethod
...
@@ -53,22 +54,15 @@ class LinearMethodBase(ABC):
...
@@ -53,22 +54,15 @@ class LinearMethodBase(ABC):
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
"""Apply the weights in layer to the input tensor.
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
raise
NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class
UnquantizedLinearMethod
(
LinearMethodBase
):
class
UnquantizedLinearMethod
(
LinearMethodBase
):
"""Linear method without quantization.
"""Linear method without quantization.
...
@@ -96,10 +90,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -96,10 +90,10 @@ class UnquantizedLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"weight"
,
weight
)
layer
.
register_parameter
(
"weight"
,
weight
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
set_weight_attrs
(
weight
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
weight
=
layer
.
weight
weight
=
layer
.
weight
if
self
.
separate_bias_add
:
if
self
.
separate_bias_add
:
if
bias
is
not
None
:
if
bias
is
not
None
:
...
@@ -116,8 +110,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
...
@@ -116,8 +110,8 @@ class UnquantizedLinearMethod(LinearMethodBase):
return
F
.
linear
(
x
,
weight
,
bias
)
return
F
.
linear
(
x
,
weight
,
bias
)
class
Replicated
Linear
(
torch
.
nn
.
Module
):
class
Linear
Base
(
torch
.
nn
.
Module
):
"""
Replicated
linear layer.
"""
Base
linear layer.
Args:
Args:
input_size: input dimension of the linear layer.
input_size: input dimension of the linear layer.
...
@@ -125,17 +119,16 @@ class ReplicatedLinear(torch.nn.Module):
...
@@ -125,17 +119,16 @@ class ReplicatedLinear(torch.nn.Module):
bias: If true, add bias.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configure
.
"""
"""
def
__init__
(
def
__init__
(
self
,
self
,
input_size
:
int
,
input_size
:
int
,
output_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
()
...
@@ -146,12 +139,46 @@ class ReplicatedLinear(torch.nn.Module):
...
@@ -146,12 +139,46 @@ class ReplicatedLinear(torch.nn.Module):
if
params_dtype
is
None
:
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
self
.
params_dtype
=
params_dtype
if
linear_method
is
None
:
if
quant_config
is
None
:
linear_method
=
UnquantizedLinearMethod
()
self
.
quant_method
:
Optional
[
self
.
linear_method
=
linear_method
QuantizeMethodBase
]
=
UnquantizedLinearMethod
()
self
.
linear_method
.
create_weights
(
self
,
self
.
input_size
,
else
:
[
self
.
output_size
],
self
.
input_size
,
self
.
quant_method
=
quant_config
.
get_quant_method
(
self
)
self
.
output_size
,
self
.
params_dtype
)
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
raise
NotImplementedError
class
ReplicatedLinear
(
LinearBase
):
"""Replicated linear layer.
Args:
input_size: input dimension of the linear layer.
output_size: output dimension of the linear layer.
bias: If true, add bias.
skip_bias_add: If true, skip adding bias but instead return it.
params_dtype: Data type for the parameters.
quant_config: Quantization configure.
"""
def
__init__
(
self
,
input_size
:
int
,
output_size
:
int
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
# All the linear layer supports quant method.
assert
self
.
quant_method
is
not
None
self
.
quant_method
.
create_weights
(
self
,
self
.
input_size
,
[
self
.
output_size
],
self
.
input_size
,
self
.
output_size
,
self
.
params_dtype
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
))
torch
.
empty
(
self
.
output_size
,
dtype
=
self
.
params_dtype
))
...
@@ -161,12 +188,19 @@ class ReplicatedLinear(torch.nn.Module):
...
@@ -161,12 +188,19 @@ class ReplicatedLinear(torch.nn.Module):
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
def
forward
(
self
,
x
:
torch
.
Tensor
)
->
torch
.
Tensor
:
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
output
=
self
.
linear_method
.
apply_weights
(
self
,
x
,
bias
)
assert
self
.
quant_method
is
not
None
output
=
self
.
quant_method
.
apply
(
self
,
x
,
bias
)
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", output_features=
{
self
.
output_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
return
s
class
ColumnParallelLinear
(
torch
.
nn
.
Modul
e
):
class
ColumnParallelLinear
(
LinearBas
e
):
"""Linear layer with column parallelism.
"""Linear layer with column parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
The linear layer is defined as Y = XA + b. A is parallelized along
...
@@ -183,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -183,7 +217,7 @@ class ColumnParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations. we
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configure
.
output_sizes: list of output sizes packed into one output, like for QKV
output_sizes: list of output sizes packed into one output, like for QKV
the list would be size 3.
the list would be size 3.
"""
"""
...
@@ -196,34 +230,28 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -196,34 +230,28 @@ class ColumnParallelLinear(torch.nn.Module):
gather_output
:
bool
=
False
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
output_sizes
:
Optional
[
List
[
int
]]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
quant_config
)
# Keep input parameters
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
gather_output
=
gather_output
self
.
gather_output
=
gather_output
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
output_size_per_partition
=
divide
(
output_size
,
tp_size
)
self
.
output_size_per_partition
=
divide
(
output_size
,
tp_size
)
self
.
skip_bias_add
=
skip_bias_add
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
if
linear_method
is
None
:
linear_method
=
UnquantizedLinearMethod
()
if
output_sizes
is
None
:
if
output_sizes
is
None
:
output_sizes
=
[
output_size
]
output_sizes
=
[
output_size
]
self
.
linear_method
=
linear_method
# All the linear layer supports quant method.
self
.
linear_method
.
create_weights
(
self
,
assert
self
.
quant_method
is
not
None
self
.
input_size
,
self
.
quant_method
.
create_weights
(
self
,
[
x
//
tp_size
for
x
in
output_sizes
],
self
.
input_size
,
self
.
input_size
,
[
x
//
tp_size
for
x
in
output_sizes
],
self
.
output_size
,
self
.
input_size
,
self
.
params_dtype
,
self
.
output_size
,
weight_loader
=
self
.
weight_loader
)
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
if
bias
:
if
bias
:
self
.
bias
=
Parameter
(
self
.
bias
=
Parameter
(
torch
.
empty
(
self
.
output_size_per_partition
,
torch
.
empty
(
self
.
output_size_per_partition
,
...
@@ -237,6 +265,10 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -237,6 +265,10 @@ class ColumnParallelLinear(torch.nn.Module):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -245,6 +277,12 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -245,6 +277,12 @@ class ColumnParallelLinear(torch.nn.Module):
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
shard_id
=
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
...
@@ -255,7 +293,8 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -255,7 +293,8 @@ class ColumnParallelLinear(torch.nn.Module):
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
bias
=
self
.
bias
if
not
self
.
skip_bias_add
else
None
# Matrix multiply.
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
self
,
input_
,
bias
)
assert
self
.
quant_method
is
not
None
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_
,
bias
)
if
self
.
gather_output
:
if
self
.
gather_output
:
# All-gather across the partitions.
# All-gather across the partitions.
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
output
=
tensor_model_parallel_all_gather
(
output_parallel
)
...
@@ -264,6 +303,14 @@ class ColumnParallelLinear(torch.nn.Module):
...
@@ -264,6 +303,14 @@ class ColumnParallelLinear(torch.nn.Module):
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
output_bias
=
self
.
bias
if
self
.
skip_bias_add
else
None
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"in_features=
{
self
.
input_size
}
"
s
+=
f
", output_features=
{
self
.
output_size_per_partition
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
get_tensor_model_parallel_world_size
()
}
"
s
+=
f
", gather_output=
{
self
.
gather_output
}
"
return
s
class
MergedColumnParallelLinear
(
ColumnParallelLinear
):
class
MergedColumnParallelLinear
(
ColumnParallelLinear
):
"""Packed linear layers with column parallelism.
"""Packed linear layers with column parallelism.
...
@@ -283,7 +330,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -283,7 +330,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configure
.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -294,13 +341,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -294,13 +341,18 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
gather_output
:
bool
=
False
,
gather_output
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
self
.
output_sizes
=
output_sizes
self
.
output_sizes
=
output_sizes
tp_size
=
get_tensor_model_parallel_world_size
()
tp_size
=
get_tensor_model_parallel_world_size
()
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
assert
all
(
output_size
%
tp_size
==
0
for
output_size
in
output_sizes
)
<<<<<<<
HEAD
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
skip_bias_add
,
params_dtype
,
linear_method
,
skip_bias_add
,
params_dtype
,
linear_method
,
=======
super
().
__init__
(
input_size
,
sum
(
output_sizes
),
bias
,
gather_output
,
skip_bias_add
,
params_dtype
,
quant_config
,
>>>>>>>
v0
.
4.2
self
.
output_sizes
)
self
.
output_sizes
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
...
@@ -311,7 +363,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -311,7 +363,12 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
param_data
=
param
.
data
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already packed.
# Loaded weight is already packed.
if
output_dim
is
None
:
if
output_dim
is
None
:
...
@@ -325,14 +382,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -325,14 +382,13 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
current_shard_offset
+=
output_size
current_shard_offset
+=
output_size
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
# Special case for Quantization.
# If quantized, we need to adjust the offset and size to account
# If quantized, we need to adjust the offset and size to account
# for the packing.
# for the packing.
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# Special case for Marlin.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
...
@@ -347,15 +403,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -347,15 +403,14 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
if
output_dim
is
not
None
:
if
output_dim
is
not
None
:
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_offset
=
sum
(
self
.
output_sizes
[:
loaded_shard_id
])
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
shard_size
=
self
.
output_sizes
[
loaded_shard_id
]
//
tp_size
# Special case for quantization.
# If quantized, we need to adjust the offset and size to account
# If quantized, we need to adjust the offset and size to account
# for the packing.
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# Special case for Marlin.
# If marlin, we need to adjust the offset and size to
# account for the tiling.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
...
@@ -368,11 +423,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
...
@@ -368,11 +423,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for AQLM codebooks.
elif
is_metadata
:
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
# metadata indicates fixed size concatenated along dim 0
shard_size
=
loaded_weight
.
shape
[
0
]
shard_size
=
loaded_weight
.
shape
[
0
]
shard_offset
=
loaded_shard_id
*
shard_size
shard_offset
=
loaded_shard_id
*
shard_size
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
param_data
=
param_data
.
narrow
(
0
,
shard_offset
,
shard_size
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_shard_id
)
else
:
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
if
not
ignore_warning
:
...
@@ -413,7 +474,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -413,7 +474,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias can be fused with other element-wise operations. we
bias can be fused with other element-wise operations. we
skip adding bias but instead return it.
skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configure
.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -425,7 +486,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -425,7 +486,7 @@ class QKVParallelLinear(ColumnParallelLinear):
bias
:
bool
=
True
,
bias
:
bool
=
True
,
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
head_size
=
head_size
self
.
head_size
=
head_size
...
@@ -453,8 +514,12 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -453,8 +514,12 @@ class QKVParallelLinear(ColumnParallelLinear):
]
]
super
().
__init__
(
input_size
,
output_size
,
bias
,
False
,
skip_bias_add
,
super
().
__init__
(
input_size
,
output_size
,
bias
,
False
,
skip_bias_add
,
<<<<<<<
HEAD
params_dtype
,
linear_method
,
output_sizes
)
params_dtype
,
linear_method
,
output_sizes
)
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
=======
params_dtype
,
quant_config
,
output_sizes
)
>>>>>>>
v0
.
4.2
def
weight_loader
(
self
,
def
weight_loader
(
self
,
param
:
Parameter
,
param
:
Parameter
,
...
@@ -462,7 +527,11 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -462,7 +527,11 @@ class QKVParallelLinear(ColumnParallelLinear):
loaded_shard_id
:
Optional
[
str
]
=
None
):
loaded_shard_id
:
Optional
[
str
]
=
None
):
param_data
=
param
.
data
param_data
=
param
.
data
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
output_dim
=
getattr
(
param
,
"output_dim"
,
None
)
# Special case for AQLM codebooks.
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
is_metadata
=
getattr
(
param
,
"is_metadata"
,
False
)
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
if
loaded_shard_id
is
None
:
if
loaded_shard_id
is
None
:
# Loaded weight is already packed.
# Loaded weight is already packed.
...
@@ -480,14 +549,14 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -480,14 +549,14 @@ class QKVParallelLinear(ColumnParallelLinear):
]
]
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
for
shard_id
,
shard_offset
,
shard_size
in
shard_offsets
:
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# If quantized, we need to adjust the offset and size to account
# for the packing.
# for the packing.
if
packed_dim
==
output_dim
:
if
packed_dim
==
output_dim
:
shard_size
=
shard_size
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# If marlin, we need to adjust the offset and size to
# Special case for Marlin.
# account for the tiling.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
...
@@ -509,6 +578,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -509,6 +578,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_offset
=
(
self
.
num_heads
+
shard_offset
=
(
self
.
num_heads
+
self
.
num_kv_heads
)
*
self
.
head_size
self
.
num_kv_heads
)
*
self
.
head_size
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
shard_size
=
self
.
num_kv_heads
*
self
.
head_size
# Special case for Quantized Weights.
# If quantized, we need to adjust the offset and size to account
# If quantized, we need to adjust the offset and size to account
# for the packing.
# for the packing.
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
packed_dim
=
getattr
(
param
,
"packed_dim"
,
None
)
...
@@ -516,8 +586,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -516,8 +586,7 @@ class QKVParallelLinear(ColumnParallelLinear):
shard_size
=
shard_size
//
param
.
pack_factor
shard_size
=
shard_size
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
shard_offset
=
shard_offset
//
param
.
pack_factor
# If marlin, we need to adjust the offset and size to
# Special case for Marlin.
# account for the tiling.
shard_size
,
shard_offset
=
adjust_marlin_shard
(
shard_size
,
shard_offset
=
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
)
param
,
shard_size
,
shard_offset
)
...
@@ -534,12 +603,17 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -534,12 +603,17 @@ class QKVParallelLinear(ColumnParallelLinear):
start_idx
=
shard_id
*
shard_size
start_idx
=
shard_id
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
output_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for for AQLM codebooks.
elif
is_metadata
:
elif
is_metadata
:
# metadata indicates fixed size concatenated along dim 0
# metadata indicates fixed size concatenated along dim 0
shard_size
=
loaded_weight
.
shape
[
0
]
shard_size
=
loaded_weight
.
shape
[
0
]
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
shard_index
=
[
"q"
,
"k"
,
"v"
].
index
(
loaded_shard_id
)
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
param_data
=
param_data
.
narrow
(
0
,
shard_index
*
shard_size
,
shard_size
)
shard_size
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
loaded_shard_id
)
else
:
else
:
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
ignore_warning
=
getattr
(
param
,
"ignore_warning"
,
False
)
if
not
ignore_warning
:
if
not
ignore_warning
:
...
@@ -559,7 +633,7 @@ class QKVParallelLinear(ColumnParallelLinear):
...
@@ -559,7 +633,7 @@ class QKVParallelLinear(ColumnParallelLinear):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class
RowParallelLinear
(
torch
.
nn
.
Modul
e
):
class
RowParallelLinear
(
LinearBas
e
):
"""Linear layer with row parallelism.
"""Linear layer with row parallelism.
The linear layer is defined as Y = XA + b. A is parallelized along
The linear layer is defined as Y = XA + b. A is parallelized along
...
@@ -582,7 +656,7 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -582,7 +656,7 @@ class RowParallelLinear(torch.nn.Module):
bias can be fused with other element-wise operations.
bias can be fused with other element-wise operations.
We skip adding bias but instead return it.
We skip adding bias but instead return it.
params_dtype: Data type for the parameters.
params_dtype: Data type for the parameters.
linear_method: (Maybe quantized) linear method
.
quant_config: Quantization configure
.
"""
"""
def
__init__
(
def
__init__
(
...
@@ -594,32 +668,26 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -594,32 +668,26 @@ class RowParallelLinear(torch.nn.Module):
skip_bias_add
:
bool
=
False
,
skip_bias_add
:
bool
=
False
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
params_dtype
:
Optional
[
torch
.
dtype
]
=
None
,
reduce_results
:
bool
=
True
,
reduce_results
:
bool
=
True
,
linear_method
:
Optional
[
LinearMethodBase
]
=
None
,
quant_config
:
Optional
[
QuantizationConfig
]
=
None
,
):
):
super
().
__init__
()
super
().
__init__
(
input_size
,
output_size
,
skip_bias_add
,
params_dtype
,
# Keep input parameters
quant_config
)
self
.
input_size
=
input_size
self
.
output_size
=
output_size
self
.
input_is_parallel
=
input_is_parallel
self
.
input_is_parallel
=
input_is_parallel
self
.
reduce_results
=
reduce_results
self
.
reduce_results
=
reduce_results
if
params_dtype
is
None
:
params_dtype
=
torch
.
get_default_dtype
()
self
.
params_dtype
=
params_dtype
# Divide the weight matrix along the last dimension.
# Divide the weight matrix along the last dimension.
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
tp_size
=
get_tensor_model_parallel_world_size
()
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
input_size_per_partition
=
divide
(
input_size
,
self
.
tp_size
)
self
.
skip_bias_add
=
skip_bias_add
# All the linear layer supports quant method.
if
linear_method
is
None
:
assert
self
.
quant_method
is
not
None
linear_method
=
UnquantizedLinearMethod
()
self
.
quant_method
.
create_weights
(
self
,
self
.
linear_method
=
linear_method
self
.
input_size_per_partition
,
self
.
linear_method
.
create_weights
(
self
,
[
self
.
output_size
],
self
.
input_size_per_partition
,
self
.
input_size
,
[
self
.
output_size
],
self
.
output_size
,
self
.
input_size
,
self
.
params_dtype
,
self
.
output_size
,
weight_loader
=
self
.
weight_loader
)
self
.
params_dtype
,
weight_loader
=
self
.
weight_loader
)
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
if
not
reduce_results
and
(
bias
and
not
skip_bias_add
):
raise
ValueError
(
"When not reduce the results, adding bias to the "
raise
ValueError
(
"When not reduce the results, adding bias to the "
...
@@ -637,6 +705,10 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -637,6 +705,10 @@ class RowParallelLinear(torch.nn.Module):
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
self
.
use_llama_nn
=
os
.
environ
.
get
(
'LLAMA_NN'
)
==
'1'
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
def
weight_loader
(
self
,
param
:
Parameter
,
loaded_weight
:
torch
.
Tensor
):
# Special case for Fp8 scales.
fp8_scales_shard_indexer
=
getattr
(
param
,
"fp8_scales_shard_indexer"
,
None
)
tp_rank
=
get_tensor_model_parallel_rank
()
tp_rank
=
get_tensor_model_parallel_rank
()
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
input_dim
=
getattr
(
param
,
"input_dim"
,
None
)
param_data
=
param
.
data
param_data
=
param
.
data
...
@@ -645,6 +717,12 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -645,6 +717,12 @@ class RowParallelLinear(torch.nn.Module):
start_idx
=
tp_rank
*
shard_size
start_idx
=
tp_rank
*
shard_size
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
loaded_weight
=
loaded_weight
.
narrow
(
input_dim
,
start_idx
,
shard_size
)
shard_size
)
# Special case for Fp8 scales.
elif
fp8_scales_shard_indexer
is
not
None
:
param_data
,
loaded_weight
=
fp8_scales_shard_indexer
(
param_data
,
loaded_weight
,
shard_id
=
0
)
assert
param_data
.
shape
==
loaded_weight
.
shape
assert
param_data
.
shape
==
loaded_weight
.
shape
if
self
.
use_llama_nn
:
if
self
.
use_llama_nn
:
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
loaded_weight
=
loaded_weight
.
transpose
(
0
,
1
)
...
@@ -662,8 +740,8 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -662,8 +740,8 @@ class RowParallelLinear(torch.nn.Module):
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
input_parallel
=
splitted_input
[
tp_rank
].
contiguous
()
# Matrix multiply.
# Matrix multiply.
output_parallel
=
self
.
linear_method
.
apply_weights
(
assert
self
.
quant_method
is
not
None
self
,
input_parallel
)
output_parallel
=
self
.
quant_method
.
apply
(
self
,
input_parallel
)
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
if
self
.
reduce_results
and
self
.
tp_size
>
1
:
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
output_
=
tensor_model_parallel_all_reduce
(
output_parallel
)
else
:
else
:
...
@@ -676,3 +754,11 @@ class RowParallelLinear(torch.nn.Module):
...
@@ -676,3 +754,11 @@ class RowParallelLinear(torch.nn.Module):
output
=
output_
output
=
output_
output_bias
=
self
.
bias
output_bias
=
self
.
bias
return
output
,
output_bias
return
output
,
output_bias
def
extra_repr
(
self
)
->
str
:
s
=
f
"input_features=
{
self
.
input_size_per_partition
}
"
s
+=
f
", output_features=
{
self
.
output_size
}
"
s
+=
f
", bias=
{
self
.
bias
is
not
None
}
"
s
+=
f
", tp_size=
{
self
.
tp_size
}
"
s
+=
f
", reduce_results=
{
self
.
reduce_results
}
"
return
s
vllm/model_executor/layers/logits_processor.py
View file @
1591c68f
...
@@ -70,6 +70,12 @@ class LogitsProcessor(nn.Module):
...
@@ -70,6 +70,12 @@ class LogitsProcessor(nn.Module):
logits
=
logits
[:,
:
self
.
org_vocab_size
]
logits
=
logits
[:,
:
self
.
org_vocab_size
]
return
logits
return
logits
def
extra_repr
(
self
)
->
str
:
s
=
f
"vocab_size=
{
self
.
vocab_size
}
"
s
+=
f
", forg_vocab_size=
{
self
.
org_vocab_size
}
"
s
+=
f
", scale=
{
self
.
scale
}
, logits_as_input=
{
self
.
logits_as_input
}
"
return
s
def
_prune_hidden_states
(
def
_prune_hidden_states
(
hidden_states
:
torch
.
Tensor
,
hidden_states
:
torch
.
Tensor
,
...
@@ -83,30 +89,27 @@ def _apply_logits_processors(
...
@@ -83,30 +89,27 @@ def _apply_logits_processors(
logits
:
torch
.
Tensor
,
logits
:
torch
.
Tensor
,
sampling_metadata
:
SamplingMetadata
,
sampling_metadata
:
SamplingMetadata
,
)
->
torch
.
Tensor
:
)
->
torch
.
Tensor
:
logits_row_idx
=
0
found_logits_processors
=
False
found_logits_processors
=
False
for
i
,
seq_group
in
enumerate
(
sampling_metadata
.
seq_groups
):
logits_processed
=
0
seq_ids
,
sampling_params
=
seq_group
for
seq_group
in
sampling_metadata
.
seq_groups
:
seq_ids
=
seq_group
.
seq_ids
sampling_params
=
seq_group
.
sampling_params
logits_processors
=
sampling_params
.
logits_processors
logits_processors
=
sampling_params
.
logits_processors
# handle prompt_logprobs by skipping rows in logits added for
# the prompt tokens (prompt logprobs are not processed)
if
(
i
<
sampling_metadata
.
num_prompts
and
sampling_params
.
prompt_logprobs
is
not
None
):
assert
len
(
seq_ids
)
==
1
logits_row_idx
+=
sampling_metadata
.
prompt_lens
[
i
]
-
1
if
logits_processors
:
if
logits_processors
:
found_logits_processors
=
True
found_logits_processors
=
True
for
seq_id
in
seq_ids
:
for
seq_id
,
logits_row_idx
in
zip
(
seq_ids
,
seq_group
.
sample_indices
):
logits_row
=
logits
[
logits_row_idx
]
logits_row
=
logits
[
logits_row_idx
]
token_ids
=
s
ampling_metadata
.
seq_data
[
seq_id
].
output_token_ids
token_ids
=
s
eq_group
.
seq_data
[
seq_id
].
output_token_ids
for
logits_processor
in
logits_processors
:
for
logits_processor
in
logits_processors
:
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits_row
=
logits_processor
(
token_ids
,
logits_row
)
logits
[
logits_row_idx
]
=
logits_row
logits
[
logits_row_idx
]
=
logits_row
logits_row_idx
+=
1
else
:
logits_processed
+=
len
(
seq_group
.
sample_indices
)
+
len
(
logits_row_idx
+=
len
(
seq_ids
)
seq_group
.
prompt_logprob_indices
)
if
found_logits_processors
:
if
found_logits_processors
:
# verifies that no rows in logits were missed unexpectedly
# verifies that no rows in logits were missed unexpectedly
assert
logits_ro
w_idx
==
logits
.
shape
[
0
]
assert
logits_
p
ro
cessed
==
logits
.
shape
[
0
]
return
logits
return
logits
vllm/model_executor/layers/quantization/__init__.py
View file @
1591c68f
from
typing
import
Type
from
typing
import
Dict
,
Type
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.aqlm
import
AQLMConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.awq
import
AWQConfig
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.fp8
import
F
P
8Config
from
vllm.model_executor.layers.quantization.fp8
import
F
p
8Config
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq
import
GPTQConfig
from
vllm.model_executor.layers.quantization.gptq_marlin
import
(
GPTQMarlinConfig
)
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.marlin
import
MarlinConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
from
vllm.model_executor.layers.quantization.squeezellm
import
SqueezeLLMConfig
QUANTIZATION_METHODS
=
{
QUANTIZATION_METHODS
:
Dict
[
str
,
Type
[
QuantizationConfig
]]
=
{
"aqlm"
:
AQLMConfig
,
"aqlm"
:
AQLMConfig
,
"awq"
:
AWQConfig
,
"awq"
:
AWQConfig
,
"fp8"
:
F
P
8Config
,
"fp8"
:
F
p
8Config
,
"gptq"
:
GPTQConfig
,
"gptq"
:
GPTQConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"squeezellm"
:
SqueezeLLMConfig
,
"gptq_marlin"
:
GPTQMarlinConfig
,
"marlin"
:
MarlinConfig
,
"marlin"
:
MarlinConfig
,
}
}
...
...
vllm/model_executor/layers/quantization/aqlm.py
View file @
1591c68f
...
@@ -8,11 +8,11 @@ import torch
...
@@ -8,11 +8,11 @@ import torch
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm._C
import
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
def
get_int_dtype
(
nbits
:
int
)
->
torch
.
dtype
:
def
get_int_dtype
(
nbits
:
int
)
->
torch
.
dtype
:
...
@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
...
@@ -207,8 +207,11 @@ class AQLMConfig(QuantizationConfig):
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
return
cls
(
in_group_size
,
nbits_per_codebook
,
num_code_books
,
out_group_size
)
out_group_size
)
def
get_linear_method
(
self
)
->
"AQLMLinearMethod"
:
def
get_quant_method
(
return
AQLMLinearMethod
(
self
)
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AQLMLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
AQLMLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[]
return
[]
...
@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
...
@@ -321,7 +324,7 @@ class AQLMLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply
_weights
(
def
apply
(
self
,
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
...
...
vllm/model_executor/layers/quantization/awq.py
View file @
1591c68f
...
@@ -4,10 +4,10 @@ import torch
...
@@ -4,10 +4,10 @@ import torch
from
torch.nn.parameter
import
Parameter
from
torch.nn.parameter
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.model_executor.layers.linear
import
(
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.utils
import
set_weight_attrs
class
AWQConfig
(
QuantizationConfig
):
class
AWQConfig
(
QuantizationConfig
):
...
@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
...
@@ -62,8 +62,11 @@ class AWQConfig(QuantizationConfig):
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
zero_point
=
cls
.
get_from_keys
(
config
,
[
"zero_point"
])
return
cls
(
weight_bits
,
group_size
,
zero_point
)
return
cls
(
weight_bits
,
group_size
,
zero_point
)
def
get_linear_method
(
self
)
->
"AWQLinearMethod"
:
def
get_quant_method
(
return
AWQLinearMethod
(
self
)
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
"AWQLinearMethod"
]:
if
isinstance
(
layer
,
LinearBase
):
return
AWQLinearMethod
(
self
)
return
None
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
def
get_scaled_act_names
(
self
)
->
List
[
str
]:
return
[
"gelu"
,
"gelu_fast"
,
"gelu_new"
,
"gelu_pytorch_tanh"
]
return
[
"gelu"
,
"gelu_fast"
,
"gelu_new"
,
"gelu_pytorch_tanh"
]
...
@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase):
...
@@ -147,10 +150,10 @@ class AWQLinearMethod(LinearMethodBase):
layer
.
register_parameter
(
"scales"
,
scales
)
layer
.
register_parameter
(
"scales"
,
scales
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
set_weight_attrs
(
scales
,
extra_weight_attrs
)
def
apply
_weights
(
self
,
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
layer
:
torch
.
nn
.
Module
,
x
:
torch
.
Tensor
,
x
:
torch
.
Tensor
,
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
bias
:
Optional
[
torch
.
Tensor
]
=
None
)
->
torch
.
Tensor
:
qweight
=
layer
.
qweight
qweight
=
layer
.
qweight
scales
=
layer
.
scales
scales
=
layer
.
scales
qzeros
=
layer
.
qzeros
qzeros
=
layer
.
qzeros
...
...
vllm/model_executor/layers/quantization/base_config.py
View file @
1591c68f
from
abc
import
ABC
,
abstractmethod
from
abc
import
ABC
,
abstractmethod
from
typing
import
Any
,
Dict
,
List
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch
import
nn
from
vllm.model_executor.layers.linear
import
LinearMethodBase
class
QuantizeMethodBase
(
ABC
):
"""Base class for different quantized methods."""
@
abstractmethod
def
create_weights
(
self
,
layer
:
torch
.
nn
.
Module
,
*
weight_args
,
**
extra_weight_attrs
):
"""Create weights for a layer.
The weights will be set as attributes of the layer."""
raise
NotImplementedError
@
abstractmethod
def
apply
(
self
,
layer
:
torch
.
nn
.
Module
,
*
args
,
**
kwargs
)
->
torch
.
Tensor
:
"""Apply the weights in layer to the input tensor.
Expects create_weights to have been called before on the layer."""
raise
NotImplementedError
def
process_weights_after_loading
(
self
,
layer
:
nn
.
Module
)
->
None
:
"""Process the weight after loading.
This can be used for example, to transpose weights for computation.
"""
return
class
QuantizationConfig
(
ABC
):
class
QuantizationConfig
(
ABC
):
...
@@ -51,8 +76,16 @@ class QuantizationConfig(ABC):
...
@@ -51,8 +76,16 @@ class QuantizationConfig(ABC):
"quantization config."
)
"quantization config."
)
@
abstractmethod
@
abstractmethod
def
get_linear_method
(
self
)
->
LinearMethodBase
:
def
get_quant_method
(
"""Get the linear method to use for the quantized linear layer."""
self
,
layer
:
torch
.
nn
.
Module
)
->
Optional
[
QuantizeMethodBase
]:
"""Get the quantize method to use for the quantized layer.
Args:
layer: The layer for the quant method.
Returns:
The quantize method. None if the given layer doesn't support quant
method.
"""
raise
NotImplementedError
raise
NotImplementedError
@
abstractmethod
@
abstractmethod
...
...
Prev
1
…
6
7
8
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment