Commit c9700db4 authored by pppppM, committed by GitHub

Fix meta tensor error in `lite` module (#848)

parent e3ac7fd5
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from accelerate import infer_auto_device_map, init_empty_weights
from transformers import AutoConfig, AutoModelForCausalLM

from lmdeploy.lite.utils import collect_target_modules
from lmdeploy.pytorch.model import LoadWoInit

LAYER_TYPE_MAP = {
    'InternLMForCausalLM': 'InternLMDecoderLayer',
    'QWenLMHeadModel': 'QWenBlock',
    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B
    'LlamaForCausalLM': 'LlamaDecoderLayer',
}

def load_hf_from_pretrained(pretrained_model_name_or_path,
                            dtype=torch.float16,
                            **kwargs):

    if dtype == torch.bfloat16 and not torch.cuda.is_bf16_supported():
        raise RuntimeError('Your device does not support bf16 (bfloat16), '
                           'please change to fp16 (float16)')

    kwargs.pop('config', None)

    hf_config = AutoConfig.from_pretrained(pretrained_model_name_or_path,
                                           torch_dtype=dtype,
                                           trust_remote_code=True)

    # HACK hard code for qwen, other configs do not have the `fp16` attribute.
    if dtype == torch.float16:
        hf_config.fp16 = True
    elif dtype == torch.bfloat16:
        hf_config.bf16 = True

    with init_empty_weights():
        # Load model
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, config=hf_config, **kwargs)
        model.config.use_cache = False

        layer_type = LAYER_TYPE_MAP[type(model).__name__]
        decoder_layers = collect_target_modules(model, layer_type)

        # Infer device map
        device_map = infer_auto_device_map(
            model, no_split_module_classes=[layer_type])
        # Keep decoder layers and the lm_head on CPU; map everything else
        # to GPU 0.
        for name in device_map.keys():
            if name in decoder_layers or 'lm_head' in name:
                device_map[name] = 'cpu'
            else:
                device_map[name] = 0

    if 'device_map' in kwargs:
        kwargs.pop('device_map')

    # Reload with real weights (LoadWoInit avoids redundant weight
    # initialization) and place modules according to the inferred device map.
    with LoadWoInit():
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path,
            device_map=device_map,
            config=hf_config,
            **kwargs)
    model.config.use_cache = False

    return model
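
For reference, a minimal usage sketch of the loader above. The import path and the model id below are illustrative assumptions, not part of this commit; extra keyword arguments are simply forwarded to `AutoModelForCausalLM.from_pretrained`.

import torch

# Assumed re-export path; adjust if the package exposes it elsewhere.
from lmdeploy.lite.utils import load_hf_from_pretrained

# 'internlm/internlm-chat-7b' is only an example model id.
model = load_hf_from_pretrained('internlm/internlm-chat-7b',
                                dtype=torch.float16,
                                trust_remote_code=True)
print(type(model).__name__)  # e.g. 'InternLMForCausalLM'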