Unverified commit c9700db4
Authored Dec 18, 2023 by pppppM; committed by GitHub on Dec 18, 2023

Fix meta tensor error in `lite` module (#848)

Parent: e3ac7fd5
Showing 1 changed file with 13 additions and 33 deletions:

lmdeploy/lite/utils/load.py (+13, -33)
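Background (a minimal sketch, not part of this commit): parameters created under accelerate's init_empty_weights() live on the torch "meta" device and hold no data, so actually running or quantizing a model loaded that way fails with a meta tensor error. The imports below are real; the toy layer is only illustrative.

import torch
import torch.nn as nn
from accelerate import init_empty_weights

with init_empty_weights():
    layer = nn.Linear(4, 4)

print(layer.weight.device)  # meta

try:
    layer(torch.ones(1, 4))
except Exception as e:
    # Fails: the weights are meta tensors and hold no data.
    print(type(e).__name__, e)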
--- a/lmdeploy/lite/utils/load.py
+++ b/lmdeploy/lite/utils/load.py
 # Copyright (c) OpenMMLab. All rights reserved.
 import torch
-from accelerate import infer_auto_device_map, init_empty_weights
 from transformers import AutoConfig, AutoModelForCausalLM
 
-from lmdeploy.lite.utils import collect_target_modules
 from lmdeploy.pytorch.model import LoadWoInit
 
-LAYER_TYPE_MAP = {
-    'InternLMForCausalLM': 'InternLMDecoderLayer',
-    'QWenLMHeadModel': 'QWenBlock',
-    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
-    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B
-    'LlamaForCausalLM': 'LlamaDecoderLayer',
-}
 
-def load_hf_from_pretrained(pretrained_model_name_or_path, **kwargs):
+def load_hf_from_pretrained(pretrained_model_name_or_path,
+                            dtype=torch.float16,
+                            **kwargs):
+
+    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
+        raise RuntimeError('Your device does not supports bf16(bfloat16), '
+                           'please change to fp16(float16)')
 
     kwargs.pop('config', None)
 
     hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
-                                           torch_dtype=torch.float16,
+                                           torch_dtype=dtype,
                                            trust_remote_code=True)
 
-    # hard code for qwen, other configs do not have the `fp16` attribute.
-    hf_config.fp16 = True
+    # HACK hard code for qwen, other configs do not have the `fp16` attribute.
+    if dtype == torch.float16:
+        hf_config.fp16 = True
+    elif dtype == torch.bfloat16:
+        hf_config.bf16 = True
 
-    with init_empty_weights():
+    with LoadWoInit():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(
             pretrained_model_name_or_path, config=hf_config, **kwargs)
         model.config.use_cache = False
-
-        layer_type = LAYER_TYPE_MAP[type(model).__name__]
-        decoder_layers = collect_target_modules(model, layer_type)
-
-    # Infer device map
-    device_map = infer_auto_device_map(model,
-                                       no_split_module_classes=[layer_type])
-    for name in device_map.keys():
-        if name in decoder_layers or 'lm_head' in name:
-            device_map[name] = 'cpu'
-        else:
-            device_map[name] = 0
-
-    if 'device_map' in kwargs:
-        kwargs.pop('device_map')
-
-    with LoadWoInit():
-        model = AutoModelForCausalLM.from_pretrained(
-            pretrained_model_name_or_path,
-            device_map=device_map,
-            config=hf_config,
-            **kwargs)
-    model.config.use_cache = False
 
     return model
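Usage (a hypothetical sketch of the new signature; the model path and extra kwargs are assumptions, not taken from the commit):

import torch
from lmdeploy.lite.utils.load import load_hf_from_pretrained

# dtype now flows through to AutoConfig/AutoModelForCausalLM; requesting
# torch.bfloat16 on a device without bf16 support raises RuntimeError early.
model = load_hf_from_pretrained('internlm/internlm-chat-7b',  # hypothetical path
                                dtype=torch.float16,
                                trust_remote_code=True)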