Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
fb377d7e
Unverified
Commit
fb377d7e
authored
Aug 13, 2024
by
Dipika Sikka
Committed by
GitHub
Aug 13, 2024
Browse files
[Misc] Update `gptq_marlin` to use new vLLMParameters (#7281)
parent
181abbc2
Changes
8
Show whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
225 additions
and
89 deletions
+225
-89
.buildkite/test-pipeline.yaml
.buildkite/test-pipeline.yaml
+10
-0
tests/weight_loading/models.txt
tests/weight_loading/models.txt
+15
-0
tests/weight_loading/run_model_weight_loading_test.sh
tests/weight_loading/run_model_weight_loading_test.sh
+32
-0
tests/weight_loading/test_weight_loading.py
tests/weight_loading/test_weight_loading.py
+20
-0
vllm/model_executor/layers/linear.py
vllm/model_executor/layers/linear.py
+3
-1
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
...on/compressed_tensors/schemes/compressed_tensors_wNa16.py
+1
-1
vllm/model_executor/layers/quantization/gptq_marlin.py
vllm/model_executor/layers/quantization/gptq_marlin.py
+59
-62
vllm/model_executor/parameter.py
vllm/model_executor/parameter.py
+85
-25
No files found.
.buildkite/test-pipeline.yaml
View file @
fb377d7e
...
@@ -314,6 +314,16 @@ steps:
...
@@ -314,6 +314,16 @@ steps:
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
export VLLM_WORKER_MULTIPROC_METHOD=spawn
-
pytest -v -s -x lora/test_long_context.py
-
pytest -v -s -x lora/test_long_context.py
-
label
:
Weight Loading Multiple GPU Test
working_dir
:
"
/vllm-workspace/tests"
num_gpus
:
2
source_file_dependencies
:
-
vllm/
-
tests/weight_loading
commands
:
-
bash weight_loading/run_model_weight_loading_test.sh
##### multi gpus test #####
##### multi gpus test #####
##### A100 test #####
##### A100 test #####
...
...
tests/weight_loading/models.txt
0 → 100644
View file @
fb377d7e
gptq_marlin, robertgshaw2/zephyr-7b-beta-channelwise-gptq, main
gptq_marlin, TheBloke/Llama-2-7B-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, main
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit--1g-actorder_True
gptq_marlin, TheBloke/TinyLlama-1.1B-Chat-v1.0-GPTQ, gptq-8bit-32g-actorder_True
gptq_marlin, TechxGenus/gemma-1.1-2b-it-GPTQ, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8w8-test-static-shape-change, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8-channel-a8-tensor, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a8-channel-dynamic-token-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w4a16-group128-v2, main
compressed-tensors, nm-testing/tinyllama-oneshot-w8a16-per-channel, main
compressed-tensors, nm-testing/Meta-Llama-3-8B-FP8-compressed-tensors-test, main
compressed-tensors, nm-testing/Phi-3-mini-128k-instruct-FP8, main
compressed-tensors, neuralmagic/Phi-3-medium-128k-instruct-quantized.w4a16, main
\ No newline at end of file
tests/weight_loading/run_model_weight_loading_test.sh
0 → 100644
View file @
fb377d7e
#!/bin/bash
# Runs weight_loading/test_weight_loading.py once for every model listed in
# weight_loading/models.txt. Each line of that file has the form
#   <quantization>, <model name>, <revision>
# Paths are relative to the current working directory.
# Exit status: 0 if every pytest run succeeded, 1 otherwise.

MODELS_FILE="weight_loading/models.txt"
SUCCESS=0

# Fail loudly rather than "passing" with zero models exercised when the list
# cannot be read (for example when started from the wrong directory).
if [[ ! -r "$MODELS_FILE" ]]; then
    echo "=== ERROR: cannot read ${MODELS_FILE} ===" >&2
    exit 1
fi

# Split the whole file on newlines, one array entry per model config.
# `read -d ''` reports a non-zero status at end of file; that is expected.
IFS=$'\n' read -d '' -r -a MODEL_CONFIGS < "$MODELS_FILE"

# An empty list would otherwise fall through the loop and exit 0.
if [[ ${#MODEL_CONFIGS[@]} -eq 0 ]]; then
    echo "=== ERROR: no model configs found in ${MODELS_FILE} ===" >&2
    exit 1
fi

for MODEL_CONFIG in "${MODEL_CONFIGS[@]}"
do
    LOCAL_SUCCESS=0

    # Break "<quantization>, <model name>, <revision>" into its fields.
    IFS=', ' read -r -a array <<< "$MODEL_CONFIG"

    echo "=== RUNNING MODEL: $MODEL_CONFIG ==="
    # The pytest module reads its configuration from these variables.
    export QUANTIZATION="${array[0]}"
    export MODEL_NAME="${array[1]}"
    export REVISION="${array[2]}"

    # Record the pytest exit status without aborting the remaining models.
    pytest -s weight_loading/test_weight_loading.py || LOCAL_SUCCESS=$?

    if [[ $LOCAL_SUCCESS == 0 ]]; then
        echo "=== PASSED MODEL: ${MODEL_CONFIG} ==="
    else
        echo "=== FAILED MODEL: ${MODEL_CONFIG} ==="
    fi

    # Accumulate exit codes; any non-zero status makes the total non-zero.
    SUCCESS=$((SUCCESS + LOCAL_SUCCESS))
done

if [ "${SUCCESS}" -eq "0" ]; then
    exit 0
else
    exit 1
fi
tests/weight_loading/test_weight_loading.py
0 → 100644
View file @
fb377d7e
import os

# Maximum model length passed to the runner.
MAX_MODEL_LEN = 1024

# Model under test; each value can be overridden through the environment
# variable of the same name and falls back to the default shown here.
MODEL_NAME = os.environ.get("MODEL_NAME",
                            "robertgshaw2/zephyr-7b-beta-channelwise-gptq")
REVISION = os.environ.get("REVISION", "main")
QUANTIZATION = os.environ.get("QUANTIZATION", "gptq_marlin")


def test_weight_loading(vllm_runner):
    """Load the configured model and check that greedy generation returns
    a non-empty result.

    Args:
        vllm_runner: callable returning a context manager that yields an
            object exposing ``generate_greedy``.
            NOTE(review): presumably the project's pytest fixture; it is
            not defined in this file.
    """
    with vllm_runner(model_name=MODEL_NAME,
                     revision=REVISION,
                     dtype="auto",
                     quantization=QUANTIZATION,
                     max_model_len=MAX_MODEL_LEN,
                     tensor_parallel_size=2) as model:

        output = model.generate_greedy("Hello world!", max_tokens=20)
        # Printed so the generation is visible when pytest runs with -s.
        print(output)
        assert output
vllm/model_executor/layers/linear.py
View file @
fb377d7e
...
@@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
...
@@ -20,7 +20,9 @@ from vllm.model_executor.utils import set_weight_attrs
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
]
WEIGHT_LOADER_V2_SUPPORTED
=
[
"CompressedTensorsLinearMethod"
,
"GPTQMarlinLinearMethod"
]
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
def
adjust_marlin_shard
(
param
,
shard_size
,
shard_offset
):
...
...
vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py
View file @
fb377d7e
...
@@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
...
@@ -105,7 +105,7 @@ class CompressedTensorsWNA16(CompressedTensorsScheme):
dtype
=
params_dtype
,
dtype
=
params_dtype
,
)
)
}
}
if
self
.
group_size
==
-
1
:
if
not
partition_scales
:
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
weight_scale
=
ChannelQuantScaleParameter
(
output_dim
=
0
,
**
weight_scale_args
)
**
weight_scale_args
)
else
:
else
:
...
...
vllm/model_executor/layers/quantization/gptq_marlin.py
View file @
fb377d7e
from
typing
import
Any
,
Dict
,
List
,
Optional
from
typing
import
Any
,
Dict
,
List
,
Optional
import
torch
import
torch
from
torch.nn
.parameter
import
Parameter
from
torch.nn
import
Parameter
from
vllm
import
_custom_ops
as
ops
from
vllm
import
_custom_ops
as
ops
from
vllm.logger
import
init_logger
from
vllm.logger
import
init_logger
from
vllm.model_executor.layers.linear
import
(
LinearBase
,
LinearMethodBase
,
from
vllm.model_executor.layers.linear
import
LinearBase
,
LinearMethodBase
set_weight_attrs
)
from
vllm.model_executor.layers.quantization.base_config
import
(
from
vllm.model_executor.layers.quantization.base_config
import
(
QuantizationConfig
)
QuantizationConfig
)
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
from
vllm.model_executor.layers.quantization.utils.marlin_utils
import
(
...
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
...
@@ -15,6 +14,11 @@ from vllm.model_executor.layers.quantization.utils.marlin_utils import (
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
marlin_repeat_scales_on_all_ranks
,
marlin_sort_g_idx
,
replace_tensor
,
verify_marlin_supported
,
verify_marlin_supports_shape
)
verify_marlin_supported
,
verify_marlin_supports_shape
)
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.layers.vocab_parallel_embedding
import
ParallelLMHead
from
vllm.model_executor.parameter
import
(
ChannelQuantScaleParameter
,
GroupQuantScaleParameter
,
PackedColumnParameter
,
PackedvLLMParameter
,
RowvLLMParameter
)
from
vllm.scalar_type
import
scalar_types
from
vllm.scalar_type
import
scalar_types
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -159,9 +163,11 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
params_dtype
:
torch
.
dtype
,
params_dtype
:
torch
.
dtype
,
**
extra_weight_attrs
,
**
extra_weight_attrs
,
)
->
None
:
)
->
None
:
del
output_size
del
output_size
output_size_per_partition
=
sum
(
output_partition_sizes
)
output_size_per_partition
=
sum
(
output_partition_sizes
)
is_row_parallel
=
input_size
!=
input_size_per_partition
is_row_parallel
=
input_size
!=
input_size_per_partition
weight_loader
=
extra_weight_attrs
.
get
(
"weight_loader"
)
# Normalize group_size
# Normalize group_size
if
self
.
quant_config
.
group_size
!=
-
1
:
if
self
.
quant_config
.
group_size
!=
-
1
:
...
@@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -190,79 +196,66 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
scales_and_zp_size
=
input_size_per_partition
//
group_size
scales_and_zp_size
=
input_size_per_partition
//
group_size
# Quantized weights
# Quantized weights
qweight
=
Parameter
(
qweight
=
PackedvLLM
Parameter
(
torch
.
empty
(
data
=
torch
.
empty
(
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
input_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
,
output_size_per_partition
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
output_dim
=
1
,
set_weight_attrs
(
packed_dim
=
0
,
qweight
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
{
weight_loader
=
weight_loader
)
**
extra_weight_attrs
,
"input_dim"
:
0
,
"output_dim"
:
1
,
"packed_dim"
:
0
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
},
)
# Activation order
# Activation order
g_idx
=
Parameter
(
g_idx
=
RowvLLMParameter
(
data
=
torch
.
empty
(
torch
.
empty
(
input_size_per_partition
,
input_size_per_partition
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
input_dim
=
0
,
)
weight_loader
=
weight_loader
)
# Ignore warning from fused linear layers such as QKVParallelLinear.
set_weight_attrs
(
g_idx
,
{
**
extra_weight_attrs
,
"input_dim"
:
0
,
"ignore_warning"
:
True
},
)
# Scales
scales
=
Parameter
(
torch
.
empty
(
scales_and_zp_size
,
output_size_per_partition
,
dtype
=
params_dtype
,
),
requires_grad
=
False
,
)
set_weight_attrs
(
scales
,
{
**
extra_weight_attrs
,
"input_dim"
:
scales_and_zp_input_dim
,
"output_dim"
:
1
,
},
)
# Quantized zero-points
qzeros_args
=
{
qzeros
=
Parameter
(
"data"
:
torch
.
empty
(
torch
.
empty
(
scales_and_zp_size
,
scales_and_zp_size
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
output_size_per_partition
//
self
.
quant_config
.
pack_factor
,
dtype
=
torch
.
int32
,
dtype
=
torch
.
int32
,
),
),
requires_grad
=
False
,
"weight_loader"
:
)
weight_loader
set_weight_attrs
(
}
qzeros
,
weight_scale_args
=
{
{
"data"
:
**
extra_weight_attrs
,
torch
.
empty
(
"input_dim"
:
scales_and_zp_input_dim
,
scales_and_zp_size
,
"output_dim"
:
1
,
output_size_per_partition
,
"packed_dim"
:
1
,
dtype
=
params_dtype
,
"pack_factor"
:
self
.
quant_config
.
pack_factor
,
),
},
"weight_loader"
:
)
weight_loader
}
if
scales_and_zp_input_dim
is
None
:
scales
=
ChannelQuantScaleParameter
(
output_dim
=
1
,
**
weight_scale_args
)
qzeros
=
PackedColumnParameter
(
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
)
else
:
scales
=
GroupQuantScaleParameter
(
output_dim
=
1
,
input_dim
=
0
,
**
weight_scale_args
)
qzeros
=
PackedvLLMParameter
(
input_dim
=
0
,
output_dim
=
1
,
packed_dim
=
1
,
packed_factor
=
self
.
quant_config
.
pack_factor
,
**
qzeros_args
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"qweight"
,
qweight
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
layer
.
register_parameter
(
"g_idx"
,
g_idx
)
...
@@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
...
@@ -280,6 +273,10 @@ class GPTQMarlinLinearMethod(LinearMethodBase):
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
def
process_weights_after_loading
(
self
,
layer
:
torch
.
nn
.
Module
)
->
None
:
device
=
layer
.
qweight
.
device
device
=
layer
.
qweight
.
device
# required by torch.compile
layer
.
qweight
=
Parameter
(
layer
.
qweight
.
data
,
requires_grad
=
False
)
layer
.
scales
=
Parameter
(
layer
.
scales
.
data
,
requires_grad
=
False
)
# Allocate marlin workspace
# Allocate marlin workspace
layer
.
workspace
=
marlin_make_workspace
(
layer
.
workspace
=
marlin_make_workspace
(
layer
.
output_size_per_partition
,
device
)
layer
.
output_size_per_partition
,
device
)
...
...
vllm/model_executor/parameter.py
View file @
fb377d7e
...
@@ -9,7 +9,7 @@ from vllm.logger import init_logger
...
@@ -9,7 +9,7 @@ from vllm.logger import init_logger
__all__
=
[
__all__
=
[
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
"BasevLLMParameter"
,
"PackedvLLMParameter"
,
"PerTensorScaleParameter"
,
"ModelWeightParameter"
,
"ChannelQuantScaleParameter"
,
"ModelWeightParameter"
,
"ChannelQuantScaleParameter"
,
"GroupQuantScaleParameter"
"GroupQuantScaleParameter"
,
"PackedColumnParameter"
,
"RowvLLMParameter"
]
]
logger
=
init_logger
(
__name__
)
logger
=
init_logger
(
__name__
)
...
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -92,7 +92,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
shard_size
=
kwargs
.
get
(
"shard_size"
)
shard_size
=
kwargs
.
get
(
"shard_size"
)
if
isinstance
(
if
isinstance
(
self
,
self
,
PackedvLLMParameter
)
and
self
.
packed_dim
==
self
.
output_dim
:
(
PackedColumnParameter
,
PackedvLLMParameter
))
and
self
.
packed_dim
==
self
.
output_dim
:
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
...
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -115,7 +116,8 @@ class _ColumnvLLMParameter(BasevLLMParameter):
if
isinstance
(
if
isinstance
(
self
,
self
,
PackedvLLMParameter
)
and
self
.
output_dim
==
self
.
packed_dim
:
(
PackedColumnParameter
,
PackedvLLMParameter
))
and
self
.
output_dim
==
self
.
packed_dim
:
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
=
self
.
adjust_shard_indexes_for_packing
(
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
shard_offset
=
shard_offset
,
shard_size
=
shard_size
)
...
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
...
@@ -131,12 +133,12 @@ class _ColumnvLLMParameter(BasevLLMParameter):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class
ModelWeight
Parameter
(
_Column
vLLMParameter
):
class
RowvLLM
Parameter
(
Base
vLLMParameter
):
"""
"""
Parameter class
for linear layer weights. Extends the
Parameter class
defining weight_loading functionality
_ColumnvLLMP
arameter b
y add
ing load
ing functionality
(load_row_parallel_weight) for p
arameter
s
b
e
ing load
ed
for
linear layers with row parallel functionality.
into
linear layers with row parallel functionality.
Requires an input
dim
ension
to be defined.
Requires an input
_
dim to be defined.
"""
"""
def
__init__
(
self
,
input_dim
:
int
,
**
kwargs
):
def
__init__
(
self
,
input_dim
:
int
,
**
kwargs
):
...
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
...
@@ -160,10 +162,18 @@ class ModelWeightParameter(_ColumnvLLMParameter):
self
.
data
.
copy_
(
loaded_weight
)
self
.
data
.
copy_
(
loaded_weight
)
class
GroupQuantScaleParameter
(
ModelWeightParameter
):
class
ModelWeightParameter
(
_ColumnvLLMParameter
,
RowvLLMParameter
):
"""
Parameter class for linear layer weights. Uses both column and
row parallelism.
"""
pass
class
GroupQuantScaleParameter
(
_ColumnvLLMParameter
,
RowvLLMParameter
):
"""
"""
Parameter class for weight scales loaded for weights with
Parameter class for weight scales loaded for weights with
grouped quantization.
Equivalent to ModelWeightParameter
.
grouped quantization.
Uses both column and row parallelism
.
"""
"""
pass
pass
...
@@ -232,6 +242,43 @@ class PerTensorScaleParameter(BasevLLMParameter):
...
@@ -232,6 +242,43 @@ class PerTensorScaleParameter(BasevLLMParameter):
param_data
.
copy_
(
loaded_weight
)
param_data
.
copy_
(
loaded_weight
)
class PackedColumnParameter(_ColumnvLLMParameter):
    """
    Parameter for model parameters which are packed on disk
    and support column parallelism only. See PackedvLLMParameter
    for more details on the packed properties.
    """

    def __init__(self,
                 packed_factor: int,
                 packed_dim: int,
                 marlin_tile_size: Optional[int] = None,
                 **kwargs):
        """Store the packing attributes and forward the remaining keyword
        arguments to the parent initializer.

        Args:
            packed_factor: divisor applied to shard sizes and offsets by
                ``adjust_shard_indexes_for_packing``.
            packed_dim: index of the packed dimension.
            marlin_tile_size: optional multiplier applied after the
                packing division; ``None`` disables that step.
            **kwargs: passed unchanged to ``super().__init__``.
        """
        self._packed_factor = packed_factor
        self._packed_dim = packed_dim
        self._marlin_tile_size = marlin_tile_size
        super().__init__(**kwargs)

    @property
    def packed_dim(self):
        """Packed dimension given at construction."""
        return self._packed_dim

    @property
    def packed_factor(self):
        """Packing factor given at construction."""
        return self._packed_factor

    @property
    def marlin_tile_size(self):
        """Marlin tile size given at construction, or ``None``."""
        return self._marlin_tile_size

    def adjust_shard_indexes_for_packing(self, shard_size, shard_offset):
        """Return ``(shard_size, shard_offset)`` adjusted for this
        parameter's packing factor and optional Marlin tile size."""
        return _adjust_shard_indexes_for_packing(
            shard_size=shard_size,
            shard_offset=shard_offset,
            packed_factor=self.packed_factor,
            marlin_tile_size=self.marlin_tile_size)
class
PackedvLLMParameter
(
ModelWeightParameter
):
class
PackedvLLMParameter
(
ModelWeightParameter
):
"""
"""
Parameter for model weights which are packed on disk.
Parameter for model weights which are packed on disk.
...
@@ -250,7 +297,7 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -250,7 +297,7 @@ class PackedvLLMParameter(ModelWeightParameter):
**
kwargs
):
**
kwargs
):
self
.
_packed_factor
=
packed_factor
self
.
_packed_factor
=
packed_factor
self
.
_packed_dim
=
packed_dim
self
.
_packed_dim
=
packed_dim
self
.
_marlin_tile
=
marlin_tile_size
self
.
_marlin_tile
_size
=
marlin_tile_size
super
().
__init__
(
**
kwargs
)
super
().
__init__
(
**
kwargs
)
@
property
@
property
...
@@ -262,16 +309,29 @@ class PackedvLLMParameter(ModelWeightParameter):
...
@@ -262,16 +309,29 @@ class PackedvLLMParameter(ModelWeightParameter):
return
self
.
_packed_factor
return
self
.
_packed_factor
@
property
@
property
def
marlin_tile
(
self
):
def
marlin_tile_size
(
self
):
return
self
.
_marlin_tile
return
self
.
_marlin_tile_size
def
_adjust_shard_indexes_for_marlin
(
self
,
shard_size
,
shard_offset
):
return
shard_size
*
self
.
marlin_tile
,
shard_offset
*
self
.
marlin_tile
def
adjust_shard_indexes_for_packing
(
self
,
shard_size
,
shard_offset
):
def
adjust_shard_indexes_for_packing
(
self
,
shard_size
,
shard_offset
):
shard_size
=
shard_size
//
self
.
packed_factor
return
_adjust_shard_indexes_for_packing
(
shard_offset
=
shard_offset
//
self
.
packed_factor
shard_size
=
shard_size
,
if
self
.
marlin_tile
is
not
None
:
shard_offset
=
shard_offset
,
return
self
.
_adjust_shard_indexes_for_marlin
(
packed_factor
=
self
.
packed_factor
,
shard_size
,
shard_offset
)
marlin_tile_size
=
self
.
marlin_tile_size
)
def
_adjust_shard_indexes_for_marlin
(
shard_size
,
shard_offset
,
marlin_tile_size
):
return
shard_size
*
marlin_tile_size
,
shard_offset
*
marlin_tile_size
def
_adjust_shard_indexes_for_packing
(
shard_size
,
shard_offset
,
packed_factor
,
marlin_tile_size
):
shard_size
=
shard_size
//
packed_factor
shard_offset
=
shard_offset
//
packed_factor
if
marlin_tile_size
is
not
None
:
return
_adjust_shard_indexes_for_marlin
(
shard_size
=
shard_size
,
shard_offset
=
shard_offset
,
marlin_tile_size
=
marlin_tile_size
)
return
shard_size
,
shard_offset
return
shard_size
,
shard_offset
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment