Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb377d7e
Unverified
Commit
fb377d7e
authored
Aug 13, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 13, 2024
Browse files
[Misc] Update `gptq_marlin` to use new vLLMParameters (#7281)
parent
181abbc2
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
225 additions
and
89 deletions
+225
-89
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+10
-0
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+15
-0
tests/weight_loading/run_model_weight_loading_test.sh
tests/weight_loading/run_model_weight_loading_test.sh
+32
-0
tests/weight_loading/test_weight_loading.py
tests/weight_loading/test_weight_loading.py
+20
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+3
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+59
-62
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+85
-25
No files found.
.buildkite/test-pipeline.yaml
View file @
fb377d7e
...
@@ -314,6 +314,16 @@ steps:
...
@@ -314,6 +314,16 @@ steps:
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
pytest -v -s -x lora/test_long_context.py
-
pytest -v -s -x lora/test_long_context.py
-
label
:
Weight Loading Multiple GPU Test
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
source_file_dependencies
:
-
vllm/
-
tests/weight_loading
commands
:
-
bash weight_loading/run_model_weight_loading_test.sh
##### multi gpus test #####
##### multi gpus test #####
##### A100 test #####
##### A100 test #####
...
...
tests/weight_loading/models.txt
0 → 100644
View file @
fb377d7e
gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
\ No newline at end of file
tests/weight_loading/run_model_weight_loading_test.sh
0 → 100644
View file @
fb377d7e
#!/bin/bash
# Run the weight-loading test once for every entry in the model list.
#
# Each line of the list has the form
#   <quantization>, <model name>, <revision>
# and its three fields are exported as QUANTIZATION, MODEL_NAME and REVISION
# before pytest is invoked. Every entry is run even if an earlier one failed;
# the script exits 0 only when all runs returned 0, and 1 otherwise.
#
# The paths are relative, so the script must be started from the directory
# that contains the weight_loading/ folder.

MODELS_FILE="weight_loading/models.txt"
SUCCESS=0

# Without this guard a missing list leaves MODEL_CONFIGS empty, the loop
# below never runs, and the script would exit 0 without testing anything.
if [[ ! -r "$MODELS_FILE" ]]; then
    echo "=== ERROR: cannot read $MODELS_FILE ===" >&2
    exit 1
fi

# Split the file on newlines into an array. `read -d ''` reports a non-zero
# status when it reaches end of file, so its status is deliberately ignored.
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$MODELS_FILE"

for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
    LOCAL_SUCCESS=0

    # Break "quantization, model, revision" into its three fields.
    IFS=', ' read -r -a array <<< "$MODEL_CONFIG"

    echo "=== RUNNING MODEL: $MODEL_CONFIG ==="
    export QUANTIZATION="${array[0]}"
    export MODEL_NAME="${array[1]}"
    export REVISION="${array[2]}"

    # Remember the pytest exit status instead of aborting, so the remaining
    # models are still exercised.
    pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?

    if [[ $LOCAL_SUCCESS == 0 ]]; then
        echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
    else
        echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
    fi

    # Accumulate exit codes; any failure makes the total non-zero.
    SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done

if [ "${SUCCESS}" -eq "0" ]; then
    exit 0
else
    exit 1
fi
tests/weight_loading/test_weight_loading.py
0 → 100644
View file @
fb377d7e
import
os
# Context length passed to the runner as ``max_model_len``.
MAX_MODEL_LEN = 1024

# The model under test is selected through environment variables; the
# fallbacks below apply when a variable is not set.
MODEL_NAME = os.getenv("MODEL_NAME",
                       "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.getenv("REVISION", "main")
QUANTIZATION = os.getenv("QUANTIZATION", "gptq_marlin")
def test_weight_loading(vllm_runner):
    """Load the configured model and check that greedy decoding yields output.

    ``vllm_runner`` is called with the module-level model settings and used
    as a context manager (presumably a pytest fixture defined elsewhere --
    confirm against the test suite's conftest). The test passes when the
    result of ``generate_greedy`` is truthy.
    """
    runner_kwargs = {
        "model_name": MODEL_NAME,
        "revision": REVISION,
        "dtype": "auto",
        "quantization": QUANTIZATION,
        "max_model_len": MAX_MODEL_LEN,
        "tensor_parallel_size": 2,
    }

    with vllm_runner(**runner_kwargs) as llm:
        generated = llm.generate_greedy("Hello world!", max_tokens=20)
        # Printed so the generation is visible when pytest runs with -s.
        print(generated)
        assert generated
vllm/model_executor/layers/linear.py
View file @
fb377d7e
...
@@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
]
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"GPTQMarlinLinearMethod"
]
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
fb377d7e
...
@@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
dtype
=
params_dtype
,
dtype
=
params_dtype
,
)
)
}
}
if
self
.
group_size
==
-
1
:
if
not
partition_scales
:
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
**
weight_scale_args
)
**
weight_scale_args
)
else
:
else
:
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
fb377d7e
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.nn
.parameter
import
Parameter
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
)
->
None
:
)
->
None
:
del
output_size
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
is_row_parallel
=
input_size
!=
input_size_per_partition
is_row_parallel
=
input_size
!=
input_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
...
@@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
scales_and_zp_size
=
input_size_per_partition
//
group_size
scales_and_zp_size
=
input_size_per_partition
//
group_size
# Quantized weights
# Quantized weights
qweight
=
Parameter
(
qweight
=
PackedvLLM
Parameter
(
torch
.
empty
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
0
,
qweight
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
{
weight_loader
=
weight_loader
)
**
extra_weight_attrs
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
# Activation order
# Activation order
g_idx
=
Parameter
(
g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
input_size_per_partition
,
dtype
=
torch
.
int32
,
),
input_dim
=
0
,
weight_loader
=
weight_loader
)
qzeros_args
=
{
"data"
:
torch
.
empty
(
torch
.
empty
(
input_size_per_partition
,
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
"weight_loader"
:
)
weight_loader
# Ignore warning from fused linear layers such as QKVParallelLinear.
}
set_weight_attrs
(
weight_scale_args
=
{
g_idx
,
"data"
:
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
},
)
# Scales
scales
=
Parameter
(
torch
.
empty
(
torch
.
empty
(
scales_and_zp_size
,
scales_and_zp_size
,
output_size_per_partition
,
output_size_per_partition
,
dtype
=
params_dtype
,
dtype
=
params_dtype
,
),
),
requires_grad
=
False
,
"weight_loader"
:
)
weight_loader
set_weight_attrs
(
}
scales
,
{
if
scales_and_zp_input_dim
is
None
:
**
extra_weight_attrs
,
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
"input_dim"
:
scales_and_zp_input_dim
,
**
weight_scale_args
)
"output_dim"
:
1
,
qzeros
=
PackedColumnParameter
(
},
output_dim
=
1
,
)
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
# Quantized zero-points
**
qzeros_args
)
qzeros
=
Parameter
(
torch
.
empty
(
else
:
scales_and_zp_size
,
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
input_dim
=
0
,
dtype
=
torch
.
int32
,
**
weight_scale_args
)
),
qzeros
=
PackedvLLMParameter
(
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
1
,
qzeros
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
{
**
qzeros_args
)
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
"packed_dim"
:
1
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
...
@@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
device
=
layer
.
qweight
.
device
# required by torch.compile
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# Allocate marlin workspace
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
layer
.
output_size_per_partition
,
device
)
...
...
vllm/model_executor/parameter.py
View file @
fb377d7e
...
@@ -9,7 +9,7 @@ from vllm.logger import init_logger
...
@@ -9,7 +9,7 @@ from vllm.logger import init_logger
__all__
=
[
__all__
=
[
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
"ModelWeightParameter"
,
"ChannelQuantScaleParameter"
,
"ModelWeightParameter"
,
"ChannelQuantScaleParameter"
,
"GroupQuantScaleParameter"
"GroupQuantScaleParameter"
,
"PackedColumnParameter"
,
"RowvLLMParameter"
]
]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_size
=
kwargs
.
get
(
"shard_size"
)
shard_size
=
kwargs
.
get
(
"shard_size"
)
if
isinstance
(
if
isinstance
(
self
,
self
,
PackedvLLMParameter
)
and
self
.
packed_dim
==
self
.
output_dim
:
(
PackedColumnParameter
,
PackedvLLMParameter
))
and
self
.
packed_dim
==
self
.
output_dim
:
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
...
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
if
isinstance
(
if
isinstance
(
self
,
self
,
PackedvLLMParameter
)
and
self
.
output_dim
==
self
.
packed_dim
:
(
PackedColumnParameter
,
PackedvLLMParameter
))
and
self
.
output_dim
==
self
.
packed_dim
:
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
...
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class
ModelWeight
Parameter
(
_Column
vLLMParameter
):
class
RowvLLM
Parameter
(
Base
vLLMParameter
):
"""
"""
Parameter class
for linear layer weights. Extends the
Parameter class
defining weight_loading functionality
_ColumnvLLMP
arameter b
y add
ing load
ing functionality
(load_row_parallel_weight) for p
arameter
s
b
e
ing load
ed
for
linear layers with row parallel functionality.
into
linear layers with row parallel functionality.
Requires an input
dim
ension
to be defined.
Requires an input
_
dim to be defined.
"""
"""
def
__init__
(
self
,
input_dim
:
int
,
**
kwargs
):
def
__init__
(
self
,
input_dim
:
int
,
**
kwargs
):
...
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
...
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
self
.
data
.
copy_
(
loaded_weight
)
self
.
data
.
copy_
(
loaded_weight
)
class
GroupQuantScaleParameter
(
ModelWeightParameter
):
class
ModelWeightParameter
(
_ColumnvLLMParameter
,
RowvLLMParameter
):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
class
GroupQuantScaleParameter
(
_ColumnvLLMParameter
,
RowvLLMParameter
):
"""
"""
Parameter class for weight scales loaded for weights with
Parameter class for weight scales loaded for weights with
grouped quantization.
Equivalent to ModelWeightParameter
.
grouped quantization.
Uses both column and row parallelism
.
"""
"""
pass
pass
...
@@ -171,7 +181,7 @@ class GroupQuantScaleParameter(ModelWeightParameter):
...
@@ -171,7 +181,7 @@ class GroupQuantScaleParameter(ModelWeightParameter):
class
ChannelQuantScaleParameter
(
_ColumnvLLMParameter
):
class
ChannelQuantScaleParameter
(
_ColumnvLLMParameter
):
"""
"""
Parameter class for weight scales loaded for weights with
Parameter class for weight scales loaded for weights with
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
channel-wise quantization. Equivalent to _ColumnvLLMParameter.
"""
"""
pass
pass
...
@@ -181,7 +191,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
...
@@ -181,7 +191,7 @@ class PerTensorScaleParameter(BasevLLMParameter):
Parameter class for scales where the number of scales is
Parameter class for scales where the number of scales is
equivalent to the number of logical matrices in fused linear
equivalent to the number of logical matrices in fused linear
layers (e.g. for QKV, there are 3 scales loaded from disk).
layers (e.g. for QKV, there are 3 scales loaded from disk).
This is relevant to weights with per-tensor quantization.
This is relevant to weights with per-tensor quantization.
Adds functionality to map the scalers to a shard during
Adds functionality to map the scalers to a shard during
weight loading.
weight loading.
...
@@ -232,6 +242,43 @@ class PerTensorScaleParameter(BasevLLMParameter):
...
@@ -232,6 +242,43 @@ class PerTensorScaleParameter(BasevLLMParameter):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedvLLMParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        """Record the packing metadata and initialise the parent parameter.

        Args:
            packed_factor: divisor applied to shard sizes and offsets when
                they are adjusted for packing.
            packed_dim: index of the dimension that is packed.
            marlin_tile_size: optional multiplier applied after the packing
                adjustment; ``None`` disables that step.
            **kwargs: forwarded unchanged to the parent initializer.
        """
        # Packing metadata is stored first; the remaining keyword arguments
        # are then handed to the parent class untouched.
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_factor(self):
        """Packing factor supplied at construction time."""
        return self._packed_factor

    @property
    def packed_dim(self):
        """Packed dimension supplied at construction time."""
        return self._packed_dim

    @property
    def marlin_tile_size(self):
        """Marlin tile size supplied at construction time, or ``None``."""
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Return ``(shard_size, shard_offset)`` adjusted for packing.

        Delegates to the module-level ``_adjust_shard_indexes_for_packing``
        helper using this parameter's packing factor and marlin tile size.
        """
        adjusted = _adjust_shard_indexes_for_packing(
            marlin_tile_size=self.marlin_tile_size,
            packed_factor=self.packed_factor,
            shard_offset=shard_offset,
            shard_size=shard_size)
        return adjusted
class
PackedvLLMParameter
(
ModelWeightParameter
):
class
PackedvLLMParameter
(
ModelWeightParameter
):
"""
"""
Parameter for model weights which are packed on disk.
Parameter for model weights which are packed on disk.
...
@@ -239,7 +286,7 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -239,7 +286,7 @@ class PackedvLLMParameter(ModelWeightParameter):
Extends the ModelWeightParameter to take in the
Extends the ModelWeightParameter to take in the
packed factor, the packed dimension, and optionally, marlin
packed factor, the packed dimension, and optionally, marlin
tile size for marlin kernels. Adjusts the shard_size and
tile size for marlin kernels. Adjusts the shard_size and
shard_offset for fused linear layers model weight loading
shard_offset for fused linear layers model weight loading
by accounting for packing and optionally, marlin tile size.
by accounting for packing and optionally, marlin tile size.
"""
"""
...
@@ -250,7 +297,7 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -250,7 +297,7 @@ class PackedvLLMParameter(ModelWeightParameter):
**
kwargs
):
**
kwargs
):
self
.
_packed_factor
=
packed_factor
self
.
_packed_factor
=
packed_factor
self
.
_packed_dim
=
packed_dim
self
.
_packed_dim
=
packed_dim
self
.
_marlin_tile
=
marlin_tile_size
self
.
_marlin_tile
_size
=
marlin_tile_size
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
@
property
@
property
...
@@ -262,16 +309,29 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -262,16 +309,29 @@ class PackedvLLMParameter(ModelWeightParameter):
return
self
.
_packed_factor
return
self
.
_packed_factor
@
property
@
property
def
marlin_tile
(
self
):
def
marlin_tile_size
(
self
):
return
self
.
_marlin_tile
return
self
.
_marlin_tile_size
def
_adjust_shard_indexes_for_marlin
(
self
,
shard_size
,
shard_offset
):
return
shard_size
*
self
.
marlin_tile
,
shard_offset
*
self
.
marlin_tile
def
adjust_shard_indexes_for_packing
(
self
,
shard_size
,
shard_offset
):
def
adjust_shard_indexes_for_packing
(
self
,
shard_size
,
shard_offset
):
shard_size
=
shard_size
//
self
.
packed_factor
return
_adjust_shard_indexes_for_packing
(
shard_offset
=
shard_offset
//
self
.
packed_factor
shard_size
=
shard_size
,
if
self
.
marlin_tile
is
not
None
:
shard_offset
=
shard_offset
,
return
self
.
_adjust_shard_indexes_for_marlin
(
packed_factor
=
self
.
packed_factor
,
shard_size
,
shard_offset
)
marlin_tile_size
=
self
.
marlin_tile_size
)
return
shard_size
,
shard_offset
def
_adjust_shard_indexes_for_marlin
(
shard_size
,
shard_offset
,
marlin_tile_size
):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def
_adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
,
packed_factor
,
marlin_tile_size
):
shard_size
=
shard_size
//
packed_factor
shard_offset
=
shard_offset
//
packed_factor
if
marlin_tile_size
is
not
None
:
return
_adjust_shard_indexes_for_marlin
(
shard_size
=
shard_size
,
shard_offset
=
shard_offset
,
marlin_tile_size
=
marlin_tile_size
)
return
shard_size
,
shard_offset
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment