Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
21063c11
Unverified
Commit
21063c11
authored
Nov 06, 2024
by
Aaron Pham
Committed by
GitHub
Nov 06, 2024
Browse files
[CI/Build] drop support for Python 3.8 EOL (#8464)
Signed-off-by:
Aaron Pham
<
contact@aarnphm.xyz
>
parent
4be3a451
Changes
115
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
46 additions
and
48 deletions
+46
-48
tests/samplers/test_rejection_sampler.py
tests/samplers/test_rejection_sampler.py
+4
-6
tests/test_logger.py
tests/test_logger.py
+1
-1
tests/tokenization/test_detokenize.py
tests/tokenization/test_detokenize.py
+2
-2
tools/profiler/print_layerwise_table.py
tools/profiler/print_layerwise_table.py
+1
-1
tools/profiler/visualize_layerwise_profile.py
tools/profiler/visualize_layerwise_profile.py
+1
-1
tools/report_build_time_ninja.py
tools/report_build_time_ninja.py
+17
-15
use_existing_torch.py
use_existing_torch.py
+1
-1
vllm/attention/ops/blocksparse_attention/interface.py
vllm/attention/ops/blocksparse_attention/interface.py
+2
-4
vllm/config.py
vllm/config.py
+4
-3
vllm/core/evictor.py
vllm/core/evictor.py
+1
-1
vllm/distributed/device_communicators/custom_all_reduce_utils.py
...stributed/device_communicators/custom_all_reduce_utils.py
+1
-1
vllm/engine/async_llm_engine.py
vllm/engine/async_llm_engine.py
+1
-1
vllm/engine/llm_engine.py
vllm/engine/llm_engine.py
+2
-2
vllm/engine/metrics_types.py
vllm/engine/metrics_types.py
+1
-1
vllm/engine/output_processor/multi_step.py
vllm/engine/output_processor/multi_step.py
+1
-1
vllm/entrypoints/chat_utils.py
vllm/entrypoints/chat_utils.py
+1
-1
vllm/entrypoints/openai/run_batch.py
vllm/entrypoints/openai/run_batch.py
+1
-1
vllm/executor/ray_gpu_executor.py
vllm/executor/ray_gpu_executor.py
+1
-1
vllm/logger.py
vllm/logger.py
+1
-2
vllm/lora/models.py
vllm/lora/models.py
+2
-2
No files found.
tests/samplers/test_rejection_sampler.py
View file @
21063c11
...
@@ -413,12 +413,10 @@ class _CorrectnessTestHelper:
...
@@ -413,12 +413,10 @@ class _CorrectnessTestHelper:
def
generate_probs_for_test
(
def
generate_probs_for_test
(
self
,
draft_and_target_probs_equal
:
bool
self
,
draft_and_target_probs_equal
:
bool
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
)
->
Tuple
[
torch
.
Tensor
,
torch
.
Tensor
,
torch
.
Tensor
]:
draft_probs
,
target_probs
=
[
draft_probs
,
target_probs
=
(
F
.
softmax
(
F
.
softmax
(
torch
.
rand
(
self
.
vocab_size
,
dtype
=
torch
.
float32
),
torch
.
rand
(
self
.
vocab_size
,
dtype
=
torch
.
float32
),
dim
=-
1
,
dim
=-
1
,
)
for
_
in
range
(
2
)
)
for
_
in
range
(
2
))
]
num_reference_probs
=
100
num_reference_probs
=
100
reference_probs
=
F
.
softmax
(
reference_probs
=
F
.
softmax
(
...
...
tests/test_logger.py
View file @
21063c11
...
@@ -29,7 +29,7 @@ def test_trace_function_call():
...
@@ -29,7 +29,7 @@ def test_trace_function_call():
cur_dir
=
os
.
path
.
dirname
(
__file__
)
cur_dir
=
os
.
path
.
dirname
(
__file__
)
enable_trace_function_call
(
path
,
cur_dir
)
enable_trace_function_call
(
path
,
cur_dir
)
f1
(
1
)
f1
(
1
)
with
open
(
path
,
'r'
)
as
f
:
with
open
(
path
)
as
f
:
content
=
f
.
read
()
content
=
f
.
read
()
assert
"f1"
in
content
assert
"f1"
in
content
...
...
tests/tokenization/test_detokenize.py
View file @
21063c11
...
@@ -93,10 +93,10 @@ def test_mistral_edge_case(tokenizer, truth):
...
@@ -93,10 +93,10 @@ def test_mistral_edge_case(tokenizer, truth):
def
skip_special_tokens
(
request
,
tokenizer_name
)
->
Generator
[
bool
,
Any
,
None
]:
def
skip_special_tokens
(
request
,
tokenizer_name
)
->
Generator
[
bool
,
Any
,
None
]:
if
"mistral"
in
tokenizer_name
:
if
"mistral"
in
tokenizer_name
:
yield
(
yield
(
bool
(
True
)
if
request
.
param
else
True
if
request
.
param
else
pytest
.
skip
(
"mistral doesn't support skip_special_tokens=False"
))
pytest
.
skip
(
"mistral doesn't support skip_special_tokens=False"
))
else
:
else
:
yield
bool
(
True
)
if
request
.
param
else
bool
(
False
)
yield
bool
(
request
.
param
)
@
pytest
.
mark
.
parametrize
(
"truth"
,
TRUTH
)
@
pytest
.
mark
.
parametrize
(
"truth"
,
TRUTH
)
...
...
tools/profiler/print_layerwise_table.py
View file @
21063c11
...
@@ -46,7 +46,7 @@ if __name__ == "__main__":
...
@@ -46,7 +46,7 @@ if __name__ == "__main__":
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
with
open
(
args
.
json_trace
,
"r"
)
as
f
:
with
open
(
args
.
json_trace
)
as
f
:
profile_data
=
json
.
load
(
f
)
profile_data
=
json
.
load
(
f
)
if
args
.
table
==
"summary"
:
if
args
.
table
==
"summary"
:
...
...
tools/profiler/visualize_layerwise_profile.py
View file @
21063c11
...
@@ -434,7 +434,7 @@ def main(
...
@@ -434,7 +434,7 @@ def main(
f
"
{
', Sparsity '
+
sparsity
if
sparsity
else
''
}
"
)
f
"
{
', Sparsity '
+
sparsity
if
sparsity
else
''
}
"
)
profile_json
=
None
profile_json
=
None
with
open
(
json_trace
,
"r"
)
as
f
:
with
open
(
json_trace
)
as
f
:
profile_json
=
json
.
load
(
f
)
profile_json
=
json
.
load
(
f
)
assert
profile_json
is
not
None
assert
profile_json
is
not
None
...
...
tools/report_build_time_ninja.py
View file @
21063c11
...
@@ -81,7 +81,7 @@ class Target:
...
@@ -81,7 +81,7 @@ class Target:
# Allow for modest floating-point errors
# Allow for modest floating-point errors
epsilon
=
0.000002
epsilon
=
0.000002
if
(
self
.
weighted_duration
>
self
.
Duration
()
+
epsilon
):
if
(
self
.
weighted_duration
>
self
.
Duration
()
+
epsilon
):
print
(
'
%s
>
%s?'
%
(
self
.
weighted_duration
,
self
.
Duration
()))
print
(
'
{}
>
{}?'
.
format
(
self
.
weighted_duration
,
self
.
Duration
()))
assert
(
self
.
weighted_duration
<=
self
.
Duration
()
+
epsilon
)
assert
(
self
.
weighted_duration
<=
self
.
Duration
()
+
epsilon
)
return
self
.
weighted_duration
return
self
.
weighted_duration
...
@@ -104,7 +104,7 @@ def ReadTargets(log, show_all):
...
@@ -104,7 +104,7 @@ def ReadTargets(log, show_all):
The result is a list of Target objects."""
The result is a list of Target objects."""
header
=
log
.
readline
()
header
=
log
.
readline
()
assert
header
==
'# ninja log v5
\n
'
,
\
assert
header
==
'# ninja log v5
\n
'
,
\
'unrecognized ninja log version
%r'
%
header
'unrecognized ninja log version
{!r}'
.
format
(
header
)
targets_dict
=
{}
targets_dict
=
{}
last_end_seen
=
0.0
last_end_seen
=
0.0
for
line
in
log
:
for
line
in
log
:
...
@@ -254,8 +254,8 @@ def SummarizeEntries(entries, extra_step_types):
...
@@ -254,8 +254,8 @@ def SummarizeEntries(entries, extra_step_types):
# Warn if the sum of weighted times is off by more than half a second.
# Warn if the sum of weighted times is off by more than half a second.
if
abs
(
length
-
weighted_total
)
>
500
:
if
abs
(
length
-
weighted_total
)
>
500
:
print
(
'Warning: Possible corrupt ninja log, results may be '
print
(
'Warning: Possible corrupt ninja log, results may be '
'untrustworthy. Length =
%
.3f, weighted total =
%
.3f
'
%
'untrustworthy. Length =
{:
.3f
}
, weighted total =
{:
.3f
}'
.
format
(
(
length
,
weighted_total
))
length
,
weighted_total
))
entries_by_ext
=
defaultdict
(
list
)
entries_by_ext
=
defaultdict
(
list
)
for
target
in
entries
:
for
target
in
entries
:
...
@@ -263,16 +263,17 @@ def SummarizeEntries(entries, extra_step_types):
...
@@ -263,16 +263,17 @@ def SummarizeEntries(entries, extra_step_types):
entries_by_ext
[
extension
].
append
(
target
)
entries_by_ext
[
extension
].
append
(
target
)
for
key
,
values
in
entries_by_ext
.
items
():
for
key
,
values
in
entries_by_ext
.
items
():
print
(
' Longest build steps for
%s:'
%
key
)
print
(
' Longest build steps for
{}:'
.
format
(
key
)
)
values
.
sort
(
key
=
lambda
x
:
x
.
WeightedDuration
())
values
.
sort
(
key
=
lambda
x
:
x
.
WeightedDuration
())
for
target
in
values
[
-
long_count
:]:
for
target
in
values
[
-
long_count
:]:
print
(
' %8.1f weighted s to build %s (%.1f s elapsed time)'
%
print
(
(
target
.
WeightedDuration
(),
target
.
DescribeTargets
(),
' {:8.1f} weighted s to build {} ({:.1f} s elapsed time)'
.
format
(
target
.
WeightedDuration
(),
target
.
DescribeTargets
(),
target
.
Duration
()))
target
.
Duration
()))
print
(
'
%
.1f s weighted time (
%
.1f s elapsed time sum,
%
1.1fx '
print
(
'
{:
.1f
}
s weighted time (
{:
.1f
}
s elapsed time sum,
{:
1.1f
}
x '
'parallelism)'
%
'parallelism)'
.
format
(
length
,
total_cpu_time
,
(
length
,
total_cpu_time
,
total_cpu_time
*
1.0
/
length
))
total_cpu_time
*
1.0
/
length
))
print
(
' %d build steps completed, average of %1.2f/s'
%
print
(
' %d build steps completed, average of %1.2f/s'
%
(
len
(
entries
),
len
(
entries
)
/
(
length
)))
(
len
(
entries
),
len
(
entries
)
/
(
length
)))
...
@@ -298,11 +299,12 @@ def main():
...
@@ -298,11 +299,12 @@ def main():
long_ext_count
+=
len
(
args
.
step_types
.
split
(
';'
))
long_ext_count
+=
len
(
args
.
step_types
.
split
(
';'
))
try
:
try
:
with
open
(
log_file
,
'r'
)
as
log
:
with
open
(
log_file
)
as
log
:
entries
=
ReadTargets
(
log
,
False
)
entries
=
ReadTargets
(
log
,
False
)
SummarizeEntries
(
entries
,
args
.
step_types
)
SummarizeEntries
(
entries
,
args
.
step_types
)
except
IOError
:
except
OSError
:
print
(
'Log file %r not found, no build summary created.'
%
log_file
)
print
(
'Log file {!r} not found, no build summary created.'
.
format
(
log_file
))
return
errno
.
ENOENT
return
errno
.
ENOENT
...
...
use_existing_torch.py
View file @
21063c11
...
@@ -4,7 +4,7 @@ requires_files = glob.glob('requirements*.txt')
...
@@ -4,7 +4,7 @@ requires_files = glob.glob('requirements*.txt')
requires_files
+=
[
"pyproject.toml"
]
requires_files
+=
[
"pyproject.toml"
]
for
file
in
requires_files
:
for
file
in
requires_files
:
print
(
f
">>> cleaning
{
file
}
"
)
print
(
f
">>> cleaning
{
file
}
"
)
with
open
(
file
,
'r'
)
as
f
:
with
open
(
file
)
as
f
:
lines
=
f
.
readlines
()
lines
=
f
.
readlines
()
if
"torch"
in
""
.
join
(
lines
).
lower
():
if
"torch"
in
""
.
join
(
lines
).
lower
():
print
(
"removed:"
)
print
(
"removed:"
)
...
...
vllm/attention/ops/blocksparse_attention/interface.py
View file @
21063c11
...
@@ -192,10 +192,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
...
@@ -192,10 +192,8 @@ class LocalStridedBlockSparseAttn(torch.nn.Module):
attn_mask
=
self
.
dense_attn_mask
[
None
,
:,
:
maxlen
,
:
maxlen
]
attn_mask
=
self
.
dense_attn_mask
[
None
,
:,
:
maxlen
,
:
maxlen
]
q2
=
self
.
transpose_and_pad
(
q
,
cu_seqlens
,
maxlen
,
1
)
q2
=
self
.
transpose_and_pad
(
q
,
cu_seqlens
,
maxlen
,
1
)
k2
,
v2
=
[
k2
,
v2
=
(
self
.
transpose_and_pad
(
x
,
cu_seqlens
,
maxlen
,
q_k_ratio
)
self
.
transpose_and_pad
(
x
,
cu_seqlens
,
maxlen
,
q_k_ratio
)
for
x
in
[
k
,
v
])
for
x
in
[
k
,
v
]
]
spda_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
spda_output
=
torch
.
nn
.
functional
.
scaled_dot_product_attention
(
q2
,
k2
,
v2
,
attn_mask
=
attn_mask
,
scale
=
sm_scale
)
q2
,
k2
,
v2
,
attn_mask
=
attn_mask
,
scale
=
sm_scale
)
return
self
.
transpose_and_unpad
(
spda_output
,
cu_seqlens
)
return
self
.
transpose_and_unpad
(
spda_output
,
cu_seqlens
)
...
...
vllm/config.py
View file @
21063c11
...
@@ -668,9 +668,10 @@ class ModelConfig:
...
@@ -668,9 +668,10 @@ class ModelConfig:
@
property
@
property
def
is_encoder_decoder_model
(
self
)
->
bool
:
def
is_encoder_decoder_model
(
self
)
->
bool
:
"""Extract the HF encoder/decoder model flag."""
"""Extract the HF encoder/decoder model flag."""
return
getattr
(
self
.
hf_config
,
"is_encoder_decoder"
,
False
)
or
(
return
getattr
(
(
hasattr
(
self
.
hf_config
,
"text_config"
)
and
getattr
(
self
.
hf_config
,
"is_encoder_decoder"
,
self
.
hf_config
.
text_config
,
"is_encoder_decoder"
,
False
)))
False
)
or
(
hasattr
(
self
.
hf_config
,
"text_config"
)
and
getattr
(
self
.
hf_config
.
text_config
,
"is_encoder_decoder"
,
False
))
@
property
@
property
def
is_multimodal_model
(
self
)
->
bool
:
def
is_multimodal_model
(
self
)
->
bool
:
...
...
vllm/core/evictor.py
View file @
21063c11
...
@@ -52,7 +52,7 @@ class Evictor(ABC):
...
@@ -52,7 +52,7 @@ class Evictor(ABC):
pass
pass
class
BlockMetaData
()
:
class
BlockMetaData
:
"""Data structure for storing key data describe cached block, so that
"""Data structure for storing key data describe cached block, so that
evitor could use to make its decision which one to choose for eviction
evitor could use to make its decision which one to choose for eviction
...
...
vllm/distributed/device_communicators/custom_all_reduce_utils.py
View file @
21063c11
...
@@ -240,7 +240,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
...
@@ -240,7 +240,7 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
if
is_distributed
:
if
is_distributed
:
get_world_group
().
barrier
()
get_world_group
().
barrier
()
logger
.
info
(
"reading GPU P2P access cache from %s"
,
path
)
logger
.
info
(
"reading GPU P2P access cache from %s"
,
path
)
with
open
(
path
,
"r"
)
as
f
:
with
open
(
path
)
as
f
:
cache
=
json
.
load
(
f
)
cache
=
json
.
load
(
f
)
_gpu_p2p_access_cache
=
cache
_gpu_p2p_access_cache
=
cache
return
_gpu_p2p_access_cache
[
f
"
{
src
}
->
{
tgt
}
"
]
return
_gpu_p2p_access_cache
[
f
"
{
src
}
->
{
tgt
}
"
]
...
...
vllm/engine/async_llm_engine.py
View file @
21063c11
...
@@ -812,7 +812,7 @@ class AsyncLLMEngine(EngineClient):
...
@@ -812,7 +812,7 @@ class AsyncLLMEngine(EngineClient):
async
def
run_engine_loop
(
engine_ref
:
ReferenceType
):
async
def
run_engine_loop
(
engine_ref
:
ReferenceType
):
"""We use a weakref to the engine so that the running loop
"""We use a weakref to the engine so that the running loop
doesn't prevent the engine being garbage collected."""
doesn't prevent the engine being garbage collected."""
engine
:
Optional
[
"
AsyncLLMEngine
"
]
=
engine_ref
()
engine
:
Optional
[
AsyncLLMEngine
]
=
engine_ref
()
if
not
engine
:
if
not
engine
:
return
return
...
...
vllm/engine/llm_engine.py
View file @
21063c11
...
@@ -1541,8 +1541,8 @@ class LLMEngine:
...
@@ -1541,8 +1541,8 @@ class LLMEngine:
seq_group
.
state
.
remaining_steps
!=
ref_remaining_steps
seq_group
.
state
.
remaining_steps
!=
ref_remaining_steps
for
seq_group
in
seq_group_metadata_list
[
1
:]
for
seq_group
in
seq_group_metadata_list
[
1
:]
]):
]):
raise
AssertionError
(
(
"All running sequence groups should "
raise
AssertionError
(
"All running sequence groups should "
"have the same remaining steps."
)
)
"have the same remaining steps."
)
return
ref_remaining_steps
>
0
return
ref_remaining_steps
>
0
...
...
vllm/engine/metrics_types.py
View file @
21063c11
...
@@ -77,7 +77,7 @@ class StatLoggerBase(ABC):
...
@@ -77,7 +77,7 @@ class StatLoggerBase(ABC):
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
num_generation_tokens
:
List
[
int
]
=
[]
self
.
last_local_log
=
time
.
time
()
self
.
last_local_log
=
time
.
time
()
self
.
local_interval
=
local_interval
self
.
local_interval
=
local_interval
self
.
spec_decode_metrics
:
Optional
[
"
SpecDecodeWorkerMetrics
"
]
=
None
self
.
spec_decode_metrics
:
Optional
[
SpecDecodeWorkerMetrics
]
=
None
@
abstractmethod
@
abstractmethod
def
log
(
self
,
stats
:
Stats
)
->
None
:
def
log
(
self
,
stats
:
Stats
)
->
None
:
...
...
vllm/engine/output_processor/multi_step.py
View file @
21063c11
...
@@ -63,7 +63,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
...
@@ -63,7 +63,7 @@ class MultiStepOutputProcessor(SequenceGroupOutputProcessor):
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
single_step_process_prompt_logprob
(
self
,
seq_group
,
output
)
@
staticmethod
@
staticmethod
@
functools
.
lru_cache
()
@
functools
.
lru_cache
def
_log_prompt_logprob_unsupported_warning_once
():
def
_log_prompt_logprob_unsupported_warning_once
():
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# Reminder: Please update docs/source/serving/compatibility_matrix.rst
# If the feature combo become valid
# If the feature combo become valid
...
...
vllm/entrypoints/chat_utils.py
View file @
21063c11
...
@@ -362,7 +362,7 @@ def load_chat_template(
...
@@ -362,7 +362,7 @@ def load_chat_template(
if
chat_template
is
None
:
if
chat_template
is
None
:
return
None
return
None
try
:
try
:
with
open
(
chat_template
,
"r"
)
as
f
:
with
open
(
chat_template
)
as
f
:
resolved_chat_template
=
f
.
read
()
resolved_chat_template
=
f
.
read
()
except
OSError
as
e
:
except
OSError
as
e
:
if
isinstance
(
chat_template
,
Path
):
if
isinstance
(
chat_template
,
Path
):
...
...
vllm/entrypoints/openai/run_batch.py
View file @
21063c11
...
@@ -120,7 +120,7 @@ async def read_file(path_or_url: str) -> str:
...
@@ -120,7 +120,7 @@ async def read_file(path_or_url: str) -> str:
session
.
get
(
path_or_url
)
as
resp
:
session
.
get
(
path_or_url
)
as
resp
:
return
await
resp
.
text
()
return
await
resp
.
text
()
else
:
else
:
with
open
(
path_or_url
,
"r"
,
encoding
=
"utf-8"
)
as
f
:
with
open
(
path_or_url
,
encoding
=
"utf-8"
)
as
f
:
return
f
.
read
()
return
f
.
read
()
...
...
vllm/executor/ray_gpu_executor.py
View file @
21063c11
...
@@ -32,7 +32,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
...
@@ -32,7 +32,7 @@ class RayGPUExecutor(DistributedGPUExecutor):
uses_ray
:
bool
=
True
uses_ray
:
bool
=
True
def
_init_executor
(
self
)
->
None
:
def
_init_executor
(
self
)
->
None
:
self
.
forward_dag
:
Optional
[
"
ray.dag.CompiledDAG
"
]
=
None
self
.
forward_dag
:
Optional
[
ray
.
dag
.
CompiledDAG
]
=
None
# If the env var is set, it uses the Ray's compiled DAG API
# If the env var is set, it uses the Ray's compiled DAG API
# which optimizes the control plane overhead.
# which optimizes the control plane overhead.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
# Run vLLM with VLLM_USE_RAY_COMPILED_DAG=1 to enable it.
...
...
vllm/logger.py
View file @
21063c11
...
@@ -67,8 +67,7 @@ def _configure_vllm_root_logger() -> None:
...
@@ -67,8 +67,7 @@ def _configure_vllm_root_logger() -> None:
raise
RuntimeError
(
raise
RuntimeError
(
"Could not load logging config. File does not exist: %s"
,
"Could not load logging config. File does not exist: %s"
,
VLLM_LOGGING_CONFIG_PATH
)
VLLM_LOGGING_CONFIG_PATH
)
with
open
(
VLLM_LOGGING_CONFIG_PATH
,
encoding
=
"utf-8"
,
with
open
(
VLLM_LOGGING_CONFIG_PATH
,
encoding
=
"utf-8"
)
as
file
:
mode
=
"r"
)
as
file
:
custom_config
=
json
.
loads
(
file
.
read
())
custom_config
=
json
.
loads
(
file
.
read
())
if
not
isinstance
(
custom_config
,
dict
):
if
not
isinstance
(
custom_config
,
dict
):
...
...
vllm/lora/models.py
View file @
21063c11
...
@@ -343,7 +343,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -343,7 +343,7 @@ class LoRAModelManager(AdapterModelManager):
# text modules (e.g. ChatGLM)
# text modules (e.g. ChatGLM)
and
hasattr
(
self
.
model
,
"get_mm_mapping"
))
and
hasattr
(
self
.
model
,
"get_mm_mapping"
))
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
packed_modules
:
Dict
[
str
,
List
[
str
]]
=
{}
self
.
modules
:
Dict
[
str
,
"
BaseLayerWithLoRA
"
]
=
{}
self
.
modules
:
Dict
[
str
,
BaseLayerWithLoRA
]
=
{}
# Dict instead of a Set for compatibility with LRUCache.
# Dict instead of a Set for compatibility with LRUCache.
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_last_mapping
:
Optional
[
LoRAMapping
]
=
None
self
.
_create_lora_modules
()
self
.
_create_lora_modules
()
...
@@ -548,7 +548,7 @@ class LoRAModelManager(AdapterModelManager):
...
@@ -548,7 +548,7 @@ class LoRAModelManager(AdapterModelManager):
else
:
else
:
parts
=
module_name
.
split
(
"."
)
parts
=
module_name
.
split
(
"."
)
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
replacements
=
self
.
packed_modules_mapping
[
parts
[
-
1
]]
subloras
:
List
[
Optional
[
"
LoRALayerWeights
"
]]
=
[]
subloras
:
List
[
Optional
[
LoRALayerWeights
]]
=
[]
for
i
,
r
in
enumerate
(
replacements
):
for
i
,
r
in
enumerate
(
replacements
):
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
lora
=
LoRALayerWeights
.
create_dummy_lora_weights
(
module_name
+
"."
+
r
,
module_name
+
"."
+
r
,
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment