change / sglang / Commits / 8bee20f8

Commit 8bee20f8 (unverified), authored Oct 19, 2024 by Yineng Zhang, committed by GitHub on Oct 19, 2024.

    Update vllm to 0.6.3 (#1711) (#1720)

    Co-authored-by: Ke Bao <ISPObaoke@163.com>

Parent: 12cad0fe
Showing 9 changed files with 133 additions and 76 deletions (+133 -76).
examples/frontend_language/usage/llava_video/srt_example_llava_v.py   +1  -1
python/pyproject.toml                                                 +1  -1
python/sglang/launch_server_llavavid.py                               +1  -1
python/sglang/srt/layers/linear.py                                    +89 -63
python/sglang/srt/lora/lora.py                                        +3  -1
python/sglang/srt/model_executor/model_runner.py                      +14 -5
python/sglang/srt/models/deepseek_v2.py                               +2  -2
python/sglang/srt/utils.py                                            +20 -0
test/srt/test_server_args.py                                          +2  -2
examples/frontend_language/usage/llava_video/srt_example_llava_v.py

@@ -208,7 +208,7 @@ if __name__ == "__main__":
     model_override_args["image_token_index"] = 64002
 
     if args.num_frames == 32:
-        model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
         model_override_args["max_sequence_length"] = 4096 * 2
         model_override_args["tokenizer_model_max_length"] = 4096 * 2
     elif args.num_frames < 32:
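Note on the change above: vllm 0.6.3 expects rope_scaling["rope_type"] where earlier releases read rope_scaling["type"] (matching the rename in newer transformers), so every place this repo builds an override dict is renamed in this commit. A minimal sketch of how such a dict reaches the server, assuming the standard sglang launch entry point and a hypothetical model path (the --json-model-override-args flag is the one exercised in test/srt/test_server_args.py at the end of this diff):

import json
import subprocess
import sys

# Hypothetical override dict; the "rope_type" key name is the point here.
model_override_args = {
    "rope_scaling": {"factor": 2.0, "rope_type": "linear"},  # was "type" before this commit
    "max_sequence_length": 4096 * 2,
}

subprocess.run(
    [
        sys.executable, "-m", "sglang.launch_server",
        "--model-path", "lmms-lab/LLaVA-NeXT-Video-7B",  # hypothetical model path
        "--json-model-override-args", json.dumps(model_override_args),
    ],
    check=True,
)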
python/pyproject.toml

@@ -26,7 +26,7 @@ runtime_common = ["aiohttp", "decord", "fastapi", "hf_transfer", "huggingface_hu
     "outlines>=0.0.44", "modelscope"]
 # xpu is not enabled in public vllm and torch whl,
 # need to follow https://docs.vllm.ai/en/latest/getting_started/xpu-installation.html install vllm
-srt = ["sglang[runtime_common]", "torch", "vllm==0.5.5"]
+srt = ["sglang[runtime_common]", "torch", "vllm==0.6.3.post1"]
 srt_xpu = ["sglang[runtime_common]"]
 openai = ["openai>=1.0", "tiktoken"]
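Because the srt extra now pins vllm==0.6.3.post1, an environment installed from an older sglang keeps its old vllm until the extra is reinstalled. A quick runtime check (a sketch, not part of this commit):

import importlib.metadata

vllm_version = importlib.metadata.version("vllm")
assert vllm_version.startswith("0.6.3"), (
    f"found vllm {vllm_version}; sglang[srt] now expects 0.6.3.post1"
)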
python/sglang/launch_server_llavavid.py

@@ -14,7 +14,7 @@ if __name__ == "__main__":
     model_override_args["num_frames"] = 16
     model_override_args["model_type"] = "llavavid"
     if model_override_args["num_frames"] == 32:
-        model_override_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
+        model_override_args["rope_scaling"] = {"factor": 2.0, "rope_type": "linear"}
         model_override_args["max_sequence_length"] = 4096 * 2
         model_override_args["tokenizer_model_max_length"] = 4096 * 2
         model_override_args["model_max_length"] = 4096 * 2
python/sglang/srt/layers/linear.py

@@ -20,8 +20,10 @@ from vllm.distributed import (
 from vllm.model_executor.layers.linear import LinearBase
 from vllm.model_executor.parameter import (
     BasevLLMParameter,
+    PackedColumnParameter,
     PackedvLLMParameter,
     PerTensorScaleParameter,
+    RowvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
@@ -39,6 +41,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQMarlinLinearMethod",
     "Fp8LinearMethod",
     "MarlinLinearMethod",
+    "GPTQLinearMethod",
 ]
@@ -50,7 +53,7 @@ def adjust_marlin_shard(param, shard_size, shard_offset):
     return shard_size * marlin_tile_size, shard_offset * marlin_tile_size
 
 
-def adjust_bitsandbytes_shard(
+def adjust_bitsandbytes_4bit_shard(
     param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""
@@ -207,7 +210,6 @@ class ReplicatedLinear(LinearBase):
             self.output_size,
             self.params_dtype,
             weight_loader=self.weight_loader,
-            prefix=prefix,
         )
         if bias:

@@ -315,7 +317,6 @@ class ColumnParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if bias:
             self.bias = Parameter(
@@ -345,8 +346,12 @@ class ColumnParallelLinear(LinearBase):
         if is_gguf_weight and isinstance(param, UninitializedParameter):
             param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
 
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+
         param_data = param.data
-        if output_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if output_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[output_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
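The new use_bitsandbytes_4bit flag changes when the loader narrows a checkpoint tensor: bitsandbytes 4-bit checkpoints already contain each rank's portion, so only ordinary weights get sliced to the tensor-parallel shard. A standalone sketch of that rule (illustrative only; shard_for_rank is not a function in this file):

import torch

def shard_for_rank(weight, output_dim, tp_rank, tp_size, use_bitsandbytes_4bit):
    if use_bitsandbytes_4bit:
        return weight  # checkpoint already holds this rank's portion
    shard_size = weight.shape[output_dim] // tp_size
    return weight.narrow(output_dim, tp_rank * shard_size, shard_size)

full = torch.randn(8, 4)
print(shard_for_rank(full, 0, tp_rank=1, tp_size=2, use_bitsandbytes_4bit=False).shape)
# torch.Size([4, 4])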
@@ -454,17 +459,22 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
-
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 2:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)

@@ -526,26 +536,17 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit:
                 shard_size = loaded_weight.shape[output_dim]
                 shard_offset = loaded_weight.shape[output_dim] * loaded_shard_id
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             start_idx = tp_rank * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -595,7 +596,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
         # If quantized, we need to adjust the offset and size to account
         # for the packing.
         if (
-            isinstance(param, PackedvLLMParameter)
+            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
             and param.packed_dim == param.output_dim
         ):
             shard_size, shard_offset = param.adjust_shard_indexes_for_packing(

@@ -617,7 +618,7 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             if isinstance(param, PerTensorScaleParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
                 param.load_merged_column_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)

@@ -760,7 +761,7 @@ class QKVParallelLinear(ColumnParallelLinear):
         # If quantized, we need to adjust the offset and size to account
         # for the packing.
         if (
-            isinstance(param, PackedvLLMParameter)
+            isinstance(param, (PackedColumnParameter, PackedvLLMParameter))
             and param.packed_dim == param.output_dim
         ):
             shard_size, shard_offset = param.adjust_shard_indexes_for_packing(

@@ -780,10 +781,10 @@ class QKVParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:
             # special case for certain models
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
+                param.load_qkv_weight(loaded_weight=loaded_weight, shard_id=0)
                 return
-            elif type(param) is BasevLLMParameter:
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+            elif type(param) in (RowvLLMParameter, BasevLLMParameter):
+                param.load_qkv_weight(loaded_weight=loaded_weight)
                 return
             self._load_fused_module_from_checkpoint(param, loaded_weight)
             return
@@ -818,17 +819,22 @@ class QKVParallelLinear(ColumnParallelLinear):
             param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
             return
 
-        if is_gguf_weight and isinstance(param, UninitializedParameter):
-            from gguf.constants import GGML_QUANT_SIZES
-
-            ori_shape = param.tensor_shape
-            weight_types = self.qweight_type.shard_weight_type.values()
-            row_size = []
-            for weight_type in weight_types:
-                block_size, type_size = GGML_QUANT_SIZES[weight_type]
-                row_size.append(ori_shape[1] // block_size * type_size)
-            q_shape = (ori_shape[0], max(row_size))
-            param.materialize(q_shape, dtype=loaded_weight.dtype)
+        if is_gguf_weight:
+            tp_size = get_tensor_model_parallel_world_size()
+            tp_rank = get_tensor_model_parallel_rank()
+
+            output_dim = getattr(param, "output_dim", None)
+            shard_size = loaded_weight.size(output_dim) // tp_size
+            start_idx = tp_rank * shard_size
+
+            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+
+            param.shard_id.append(loaded_shard_id)
+            param.shard_id_map[loaded_shard_id] = len(param.data_container)
+            param.data_container.append(loaded_weight)
+            if len(param.data_container) == 3:
+                self.qweight = param.materialize_nested()
+            return
 
         param_data = param.data
         output_dim = getattr(param, "output_dim", None)
@@ -863,6 +869,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     self.total_num_kv_heads * self.head_size,
                 ),
             ]
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            packed_dim = getattr(param, "packed_dim", None)
 
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantized Weights.

@@ -877,6 +885,29 @@ class QKVParallelLinear(ColumnParallelLinear):
                         param, shard_size, shard_offset
                     )
 
+                if use_bitsandbytes_4bit:
+                    orig_qkv_offsets = {
+                        "q": (0, self.total_num_heads * self.head_size),
+                        "k": (
+                            self.total_num_heads * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "v": (
+                            (self.total_num_heads + self.total_num_kv_heads) * self.head_size,
+                            self.total_num_kv_heads * self.head_size,
+                        ),
+                        "total": (
+                            (self.total_num_heads + 2 * self.total_num_kv_heads) * self.head_size,
+                            0,
+                        ),
+                    }
+
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_qkv_offsets, shard_id
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
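The offsets dict added here records where q, k and v live along the output dimension of the fused QKV weight; adjust_bitsandbytes_4bit_shard (renamed above) then adjusts those (offset, size) pairs for the 4-bit-quantized layout, per its docstring. A toy layout with hypothetical head counts (illustrative only):

total_num_heads, total_num_kv_heads, head_size = 32, 8, 128  # assumed values

qkv_offsets = {
    "q": (0, total_num_heads * head_size),
    "k": (total_num_heads * head_size, total_num_kv_heads * head_size),
    "v": (
        (total_num_heads + total_num_kv_heads) * head_size,
        total_num_kv_heads * head_size,
    ),
    "total": ((total_num_heads + 2 * total_num_kv_heads) * head_size, 0),
}
for shard_id, (offset, size) in qkv_offsets.items():
    print(f"{shard_id}: offset={offset}, size={size}")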
@@ -910,8 +941,8 @@ class QKVParallelLinear(ColumnParallelLinear):
                     param, shard_size, shard_offset
                 )
 
-            use_bitsandbytes = getattr(param, "use_bitsandbytes", False)
-            if use_bitsandbytes:
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
+            if use_bitsandbytes_4bit:
                 orig_qkv_offsets = {
                     "q": (0, self.num_heads * self.head_size),
                     "k": (
@@ -927,29 +958,22 @@ class QKVParallelLinear(ColumnParallelLinear):
                         0,
                     ),
                 }
-                shard_size, shard_offset = adjust_bitsandbytes_shard(
+                shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
                     param, orig_qkv_offsets, loaded_shard_id
                 )
 
-            if is_gguf_weight:
-                tp_size = get_tensor_model_parallel_world_size()
-                output_dim = getattr(param, "output_dim", None)
-                shard_shape = list(loaded_weight.shape)
-                shard_shape[output_dim] = shard_shape[output_dim] // tp_size
-                param.shard_id.append(loaded_shard_id)
-                param.shard_size[loaded_shard_id] = shard_shape
-
-                input_dim = getattr(param, "input_dim", None)
-                input_size = loaded_weight.shape[input_dim]
-                param_data = param_data.narrow(input_dim, 0, input_size)
-
             param_data = param_data.narrow(output_dim, shard_offset, shard_size)
             if loaded_shard_id == "q":
                 shard_id = tp_rank
             else:
                 shard_id = tp_rank // self.num_kv_head_replicas
             start_idx = shard_id * shard_size
-            loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
+            # bitsandbytes loads the weights of the specific portion
+            # no need to narrow here
+            if not use_bitsandbytes_4bit:
+                loaded_weight = loaded_weight.narrow(output_dim, start_idx, shard_size)
         # Special case for for AQLM codebooks.
         elif is_metadata:
             # metadata indicates fixed size concatenated along dim 0
@@ -1037,7 +1061,6 @@ class RowParallelLinear(LinearBase):
                 if self.quant_method.__class__.__name__ in WEIGHT_LOADER_V2_SUPPORTED
                 else self.weight_loader
             ),
-            prefix=prefix,
         )
         if not reduce_results and (bias and not skip_bias_add):
             raise ValueError(
@@ -1061,6 +1084,7 @@ class RowParallelLinear(LinearBase):
         tp_rank = get_tensor_model_parallel_rank()
         tp_size = get_tensor_model_parallel_world_size()
         input_dim = getattr(param, "input_dim", None)
+        use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
 
         # Special case for GGUF
         is_gguf_weight = getattr(param, "is_gguf_weight", False)

@@ -1076,7 +1100,9 @@ class RowParallelLinear(LinearBase):
             param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)
 
         param_data = param.data
-        if input_dim is not None:
+        # bitsandbytes loads the weights of the specific portion
+        # no need to narrow here
+        if input_dim is not None and not use_bitsandbytes_4bit:
             shard_size = param_data.shape[input_dim]
             start_idx = tp_rank * shard_size
             loaded_weight = loaded_weight.narrow(input_dim, start_idx, shard_size)
python/sglang/srt/lora/lora.py

@@ -351,7 +351,9 @@ class LoRAAdapter(nn.Module):
         loader = DefaultModelLoader(self.load_config)
         revision = getattr(self.config.hf_config, "revision", None)
         for name, loaded_weight in loader._get_weights_iterator(
-            model_path, revision=revision, fall_back_to_pt=True
+            DefaultModelLoader.Source(
+                model_path, revision=revision, fall_back_to_pt=True
+            )
         ):
             match = re.search(r"layers\.(\d+)\.", name)
             if match is not None:
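vllm 0.6.3 changed the signature of DefaultModelLoader._get_weights_iterator: it now takes a single DefaultModelLoader.Source descriptor instead of separate model/revision/fall_back_to_pt arguments, which is why both this call site and model_runner.py below are updated. A sketch of the new call shape (assuming the vllm 0.6.x import path for the loader; the model path is hypothetical and the loader construction is left to the caller):

from vllm.model_executor.model_loader.loader import DefaultModelLoader

source = DefaultModelLoader.Source(
    "some-org/some-model",  # hypothetical model path
    revision=None,
    fall_back_to_pt=True,
)
# loader = DefaultModelLoader(load_config)  # load_config supplied by the caller
# for name, weight in loader._get_weights_iterator(source):
#     ...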
python/sglang/srt/model_executor/model_runner.py

@@ -59,8 +59,11 @@ from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     enable_show_time_cost,
     get_available_gpu_memory,
+    is_attention_free_model,
+    is_embedding_model,
     is_generation_model,
     is_multimodal_model,
+    model_has_inner_state,
     monkey_patch_vllm_dummy_weight_loader,
     monkey_patch_vllm_p2p_access_check,
 )
@@ -316,11 +319,13 @@ class ModelRunner:
         def get_weight_iter(config):
             iter = loader._get_weights_iterator(
-                config.model,
-                config.revision,
-                fall_back_to_pt=getattr(
-                    self.model, "fall_back_to_pt_during_load", True
-                ),
+                DefaultModelLoader.Source(
+                    config.model,
+                    revision=config.revision,
+                    fall_back_to_pt=getattr(
+                        self.model, "fall_back_to_pt_during_load", True
+                    ),
+                )
             )
             return iter
@@ -662,3 +667,7 @@ def load_model_cls_srt(model_arch: str) -> Optional[Type[nn.Module]]:
 # Monkey patch model loader
 setattr(ModelRegistry, "_try_load_model_cls", load_model_cls_srt)
 setattr(ModelRegistry, "is_multimodal_model", is_multimodal_model)
+setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
+setattr(ModelRegistry, "model_has_inner_state", model_has_inner_state)
+setattr(ModelRegistry, "is_embedding_model", is_embedding_model)
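The extra setattr calls exist presumably because vllm 0.6.3's ModelRegistry exposes these predicates (attention-free models, inner state, embedding models) and the engine consults them during startup; sglang points them at the stubs added to sglang.srt.utils below. A toy stand-in showing the patching pattern (ModelRegistry here is a placeholder class, not vllm's):

class ModelRegistry:  # placeholder for vllm.ModelRegistry
    pass

def is_attention_free_model(model_architectures):
    return False

setattr(ModelRegistry, "is_attention_free_model", is_attention_free_model)
print(ModelRegistry.is_attention_free_model(["LlamaForCausalLM"]))  # False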
python/sglang/srt/models/deepseek_v2.py

@@ -250,7 +250,7 @@ class DeepseekV2Attention(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        rope_scaling["type"] = "deepseek_yarn"
+        rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,

@@ -398,7 +398,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             bias=False,
             quant_config=quant_config,
         )
-        rope_scaling["type"] = "deepseek_yarn"
+        rope_scaling["rope_type"] = "deepseek_yarn"
        self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
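Both attention variants switch the injected key from "type" to "rope_type", matching what get_rope expects after the vllm upgrade. If a config still carries the old key, a small normalization shim like the following can bridge the two conventions during migration (a sketch, not part of this commit):

def normalize_rope_scaling(rope_scaling: dict) -> dict:
    out = dict(rope_scaling)
    if "type" in out and "rope_type" not in out:
        out["rope_type"] = out.pop("type")
    return out

print(normalize_rope_scaling({"factor": 2.0, "type": "linear"}))
# {'factor': 2.0, 'rope_type': 'linear'}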
python/sglang/srt/utils.py

@@ -215,6 +215,26 @@ def is_multimodal_model(model_architectures):
     return False
 
 
+def is_attention_free_model(model_architectures):
+    return False
+
+
+def model_has_inner_state(model_architectures):
+    return False
+
+
+def is_embedding_model(model_architectures):
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
 def is_generation_model(model_architectures, is_embedding: bool = False):
     # We have two ways to determine whether a model is a generative model.
     # 1. Check the model architectue
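These helpers mirror the registry hooks patched in model_runner.py above; is_embedding_model keys off the architecture strings from the HF config, while the other two are constant stubs for now. Quick usage (assuming the definitions above are importable from sglang.srt.utils):

from sglang.srt.utils import is_attention_free_model, is_embedding_model

print(is_embedding_model(["LlamaEmbeddingModel"]))    # True
print(is_embedding_model(["LlamaForCausalLM"]))       # False
print(is_attention_free_model(["MambaForCausalLM"]))  # False (stub always returns False)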
test/srt/test_server_args.py

@@ -11,13 +11,13 @@ class TestPrepareServerArgs(unittest.TestCase):
                 "--model-path",
                 "model_path",
                 "--json-model-override-args",
-                '{"rope_scaling": {"factor": 2.0, "type": "linear"}}',
+                '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}',
             ]
         )
 
         self.assertEqual(server_args.model_path, "model_path")
         self.assertEqual(
             json.loads(server_args.json_model_override_args),
-            {"rope_scaling": {"factor": 2.0, "type": "linear"}},
+            {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}},
         )
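The test simply checks that the new key survives CLI parsing and a JSON round-trip. The core of the assertion can be reproduced standalone (grounded in the strings above):

import json

cli_value = '{"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}'
assert json.loads(cli_value) == {"rope_scaling": {"factor": 2.0, "rope_type": "linear"}}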