Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Lmdeploy
Commits
c9700db4
"googlemock/vscode:/vscode.git/clone" did not exist on "9c2293af064504f1a7296a2397211be8809452d9"
Unverified
Commit
c9700db4
authored
Dec 18, 2023
by
pppppM
Committed by
GitHub
Dec 18, 2023
Browse files
Fix meta tensor error in `lite` module(#848)
parent
e3ac7fd5
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
33 deletions
+13
-33
lmdeploy/lite/utils/load.py
lmdeploy/lite/utils/load.py
+13
-33
No files found.
lmdeploy/lite/utils/load.py
View file @
c9700db4
# Copyright (c) OpenMMLab. All rights reserved.
import
torch
from
accelerate
import
infer_auto_device_map
,
init_empty_weights
from
transformers
import
AutoConfig
,
AutoModelForCausalLM
from
lmdeploy.lite.utils
import
collect_target_modules
from
lmdeploy.pytorch.model
import
LoadWoInit
LAYER_TYPE_MAP
=
{
'InternLMForCausalLM'
:
'InternLMDecoderLayer'
,
'QWenLMHeadModel'
:
'QWenBlock'
,
'BaiChuanForCausalLM'
:
'DecoderLayer'
,
# Baichuan 7B
'BaichuanForCausalLM'
:
'DecoderLayer'
,
# Baichuan2 7B
'LlamaForCausalLM'
:
'LlamaDecoderLayer'
,
}
def
load_hf_from_pretrained
(
pretrained_model_name_or_path
,
dtype
=
torch
.
float16
,
**
kwargs
):
def
load_hf_from_pretrained
(
pretrained_model_name_or_path
,
**
kwargs
):
if
dtype
==
torch
.
bfloat16
and
not
torch
.
cuda
.
is_bf16_supported
():
raise
RuntimeError
(
'Your device does not supports bf16(bfloat16), '
'please change to fp16(float16)'
)
kwargs
.
pop
(
'config'
,
None
)
hf_config
=
AutoConfig
.
from_pretrained
(
pretrained_model_name_or_path
,
torch_dtype
=
torch
.
float16
,
torch_dtype
=
dtype
,
trust_remote_code
=
True
)
# hard code for qwen, other configs do not have the `fp16` attribute.
# HACK hard code for qwen, other configs do not have the `fp16` attribute.
if
dtype
==
torch
.
float16
:
hf_config
.
fp16
=
True
elif
dtype
==
torch
.
bfloat16
:
hf_config
.
bf16
=
True
with
init_empty_weights
():
with
LoadWoInit
():
# Load model
model
=
AutoModelForCausalLM
.
from_pretrained
(
pretrained_model_name_or_path
,
config
=
hf_config
,
**
kwargs
)
model
.
config
.
use_cache
=
False
layer_type
=
LAYER_TYPE_MAP
[
type
(
model
).
__name__
]
decoder_layers
=
collect_target_modules
(
model
,
layer_type
)
# Infer device map
device_map
=
infer_auto_device_map
(
model
,
no_split_module_classes
=
[
layer_type
])
for
name
in
device_map
.
keys
():
if
name
in
decoder_layers
or
'lm_head'
in
name
:
device_map
[
name
]
=
'cpu'
else
:
device_map
[
name
]
=
0
if
'device_map'
in
kwargs
:
kwargs
.
pop
(
'device_map'
)
with
LoadWoInit
():
model
=
AutoModelForCausalLM
.
from_pretrained
(
pretrained_model_name_or_path
,
device_map
=
device_map
,
config
=
hf_config
,
**
kwargs
)
model
.
config
.
use_cache
=
False
return
model
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment