Unverified Commit 09c73fb2 authored by Casper, committed by GitHub

Fix multi-GPU loading and inference (#190)

parent 299c460b
@@ -45,13 +45,13 @@ class AutoAWQForCausalLM:
    def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                       trust_remote_code=True, fuse_layers=True,
                       batch_size=1, safetensors=True,
-                      max_memory=None, offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
+                      device_map="balanced", offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
        os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
        model_type = check_and_get_model_type(quant_path, trust_remote_code)

        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
            quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
            fuse_layers=fuse_layers, safetensors=safetensors,
-            max_memory=max_memory, offload_folder=offload_folder,
+            device_map=device_map, offload_folder=offload_folder,
            **config_kwargs
        )
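With this change, callers pick GPU placement through device_map instead of hand-building a max_memory dict. A minimal usage sketch, assuming a quantized AWQ checkpoint; the model id below is a placeholder, not part of this diff:

# Placeholder checkpoint id; device_map="balanced" is the new default from this commit.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "someuser/some-model-awq",   # hypothetical quantized checkpoint
    fuse_layers=True,
    safetensors=True,
    device_map="balanced",       # spread layers evenly across all visible GPUs
)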
@@ -14,8 +14,12 @@ from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
+from accelerate.big_modeling import (
+    init_empty_weights,
+    infer_auto_device_map,
+    load_checkpoint_and_dispatch,
+)
+from accelerate.utils import get_balanced_memory

class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
@@ -109,9 +113,17 @@ class BaseAWQForCausalLM(nn.Module):
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

+        # Evenly distribute memory on GPUs
+        max_memory = get_balanced_memory(
+            model,
+            no_split_module_classes=[self.layer_type],
+            dtype=torch_dtype
+        )
+
        # Get device map
        device_map = infer_auto_device_map(
            model,
+            max_memory=max_memory,
            no_split_module_classes=[self.layer_type],
            dtype=torch_dtype
        )
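The two accelerate calls above work together: get_balanced_memory computes a per-device memory budget that spreads the model evenly instead of filling GPU 0 first, and infer_auto_device_map turns that budget into a module-to-device mapping. A standalone sketch, assuming at least one visible GPU; the OPT names are illustrative stand-ins, not AutoAWQ code:

import torch
from accelerate import init_empty_weights, infer_auto_device_map
from accelerate.utils import get_balanced_memory
from transformers import AutoConfig, AutoModelForCausalLM

# Stand-in architecture; any causal LM config works the same way.
config = AutoConfig.from_pretrained("facebook/opt-1.3b")
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

max_memory = get_balanced_memory(
    model,
    no_split_module_classes=["OPTDecoderLayer"],  # never split a decoder block across devices
    dtype=torch.float16,
)
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["OPTDecoderLayer"],
    dtype=torch.float16,
)
print(device_map)  # e.g. {'model.decoder.embed_tokens': 0, ..., 'lm_head': 1}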
@@ -123,6 +135,7 @@ class BaseAWQForCausalLM(nn.Module):
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            use_safetensors=safetensors,
+            device_map=device_map,
            **model_init_kwargs
        )
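Passing the computed device_map into from_pretrained lets transformers and accelerate shard the half-precision weights across devices at load time. A hedged equivalent in plain transformers, with a placeholder model id:

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-1.3b",       # placeholder model id
    torch_dtype=torch.float16,
    device_map="balanced",     # a precomputed dict, "auto", or "balanced" all work here
)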
@@ -135,7 +148,7 @@ class BaseAWQForCausalLM(nn.Module):
                       max_new_tokens=None, torch_dtype=torch.float16,
                       trust_remote_code=True, safetensors=True, is_quantized=True,
                       fuse_layers=False, version='GEMM',
-                      max_memory=None, offload_folder=None,
+                      device_map="balanced", offload_folder=None,
                       **config_kwargs):
        # [STEP 1-2] Load weights path and configs
        model_weights_path, config, quant_config = self._load_config(
@@ -153,36 +166,21 @@ class BaseAWQForCausalLM(nn.Module):
        model.tie_weights()

-        # Get device map
-        device_map = infer_auto_device_map(
-            model,
-            no_split_module_classes=[self.layer_type],
-            max_memory=max_memory,
-            dtype=torch_dtype
-        )
-
-        # Load checkpoint
-        load_checkpoint_in_model(
+        # loads the weights into modules and distributes
+        # across available devices automatically
+        load_checkpoint_and_dispatch(
            model,
            checkpoint=model_weights_path,
            device_map=device_map,
+            no_split_module_classes=[self.layer_type],
            offload_folder=offload_folder,
-            dtype=torch_dtype
+            dtype=torch_dtype,
        )

        # Dispath to devices
        if fuse_layers:
            self.fuse_layers(model)

-        # Offloading dispatch
-        from accelerate import dispatch_model
-        model = dispatch_model(
-            model,
-            device_map=device_map,
-            offload_dir=offload_folder
-        )
-
        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

    def _load_config(self, model_path, model_filename, safetensors=True,
......
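The net effect of this hunk is that three steps (infer_auto_device_map, load_checkpoint_in_model, dispatch_model) collapse into a single load_checkpoint_and_dispatch call, which places modules per the device map, loads weights directly onto their target devices, and registers the hooks that shuttle activations between devices at inference time. A self-contained sketch of that call outside AutoAWQ, with placeholder paths and class names (the real call sites pass self.layer_type and the resolved weights path):

import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("/path/to/awq-model")  # hypothetical local directory
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model.tie_weights()

# One call: place modules per device_map, load weights straight onto their target
# devices, and register the hooks that move activations between devices.
model = load_checkpoint_and_dispatch(
    model,
    checkpoint="/path/to/awq-model",                 # directory containing *.safetensors shards
    device_map="balanced",
    no_split_module_classes=["LlamaDecoderLayer"],   # keep each decoder block on one device
    dtype=torch.float16,
)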
@@ -19,6 +19,7 @@ class LlamaLikeBlock(nn.Module):
        ).to(dev)
        self.norm_2 = norm_2.to(dev)
        self.mlp = mlp.to(dev)
+        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
@@ -30,7 +31,7 @@ class LlamaLikeBlock(nn.Module):
            attention_mask=attention_mask
        )

-        h = hidden_states + attn_output
+        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.mlp.forward(self.norm_2(h))

        return out, None, past_key_value
@@ -48,6 +49,7 @@ class MPTBlock(nn.Module):
        ).to(dev)
        self.norm_2 = norm_2
        self.ffn = mpt_mlp.to(dev)
+        self.device = dev

    def forward(
        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
@@ -62,7 +64,7 @@ class MPTBlock(nn.Module):
            use_cache=True
        )

-        h = hidden_states + attn_output
+        h = hidden_states.to(attn_output.device) + attn_output
        out = h + self.ffn.forward(self.norm_2(h))

        return out, None, past_key_value
@@ -94,6 +96,7 @@ class FalconDecoderLayer(nn.Module):
        self.input_layernorm = input_layernorm # before attention
        self.mlp = mlp
+        self.device = dev

    def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
        batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
@@ -136,7 +139,7 @@ class FalconDecoderLayer(nn.Module):
            use_cache=True
        )

-        h_attn = hidden_states + attn_output
+        h_attn = hidden_states.to(attn_output.device) + attn_output

        if self.new_decoder_arch:
            h_mlp = self.mlp.forward(mlp_layernorm_out)
......
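All three residual changes above address the same failure: once blocks are dispatched to different GPUs, the incoming hidden states may still sit on the previous block's device while the attention output lives on the current one, and adding them raises a device-mismatch error. A tiny illustration of the fix, which falls back to CPU when fewer than two GPUs are visible:

import torch

dev_prev = "cuda:0" if torch.cuda.device_count() >= 2 else "cpu"
dev_this = "cuda:1" if torch.cuda.device_count() >= 2 else "cpu"

hidden_states = torch.randn(1, 8, 64, device=dev_prev)  # produced by the previous block
attn_output = torch.randn(1, 8, 64, device=dev_this)    # produced by this block's attention

# Old form: hidden_states + attn_output fails with a device-mismatch RuntimeError across GPUs.
h = hidden_states.to(attn_output.device) + attn_output  # the pattern used in the blocks above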
import torch
import torch.nn as nn
from typing import List
+from awq.utils import fused_utils
from transformers.modeling_outputs import BaseModelOutputWithPast
from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
-from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids, prepare_cache

class LlamaLikeModel(nn.Module):
    """
@@ -20,17 +20,17 @@ class LlamaLikeModel(nn.Module):
    @torch.inference_mode()
    def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.embedding(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
@@ -38,7 +38,17 @@ class LlamaLikeModel(nn.Module):
        )

        for layer in self.blocks:
-            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(
+                h,
+                None,
+                attention_mask=mask,
+                is_causal=is_causal
+            )
        h = self.norm(h)

        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
@@ -56,17 +66,17 @@ class MPTModel(nn.Module):
    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.wte(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
@@ -74,7 +84,17 @@ class MPTModel(nn.Module):
        )

        for layer in self.blocks:
-            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(
+                h,
+                None,
+                attention_mask=mask,
+                is_causal=is_causal
+            )
        h = self.norm_f(h)

        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
@@ -92,17 +112,17 @@ class FalconModel(nn.Module):
    @torch.inference_mode()
    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
            input_ids,
            self.last_forward_num_tokens
        )
        _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

        h = self.word_embeddings(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
            seqlen=seqlen,
            start_pos=self.blocks[0].attn.start_pos,
            device=input_ids.device,
@@ -110,7 +130,17 @@ class FalconModel(nn.Module):
        )

        for layer in self.blocks:
-            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(
+                h,
+                None,
+                attention_mask=mask,
+                is_causal=is_causal
+            )
        h = self.ln_f(h)

        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
import torch
from typing import List
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

+def prepare_correct_devices(next_layer, hidden_states, mask):
+    hidden_states = hidden_states.to(next_layer.device)
+
+    if mask is not None:
+        mask = mask.to(next_layer.device)
+
+    return hidden_states, mask
+
def prepare_cache(blocks, seqlen: int) -> int:
    for block in blocks:
        start_pos = block.attn.start_pos
......
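prepare_correct_devices, added above, is what lets the per-layer loops in the fused models hop activations from GPU to GPU as they reach each block. A toy, self-contained version of the same pattern; the blocks and helper below are stand-ins for illustration, not library code:

import torch
import torch.nn as nn

def move_to_block_device(block, hidden_states, mask):
    # Mirrors the helper above: bring activations to the device of the block
    # that is about to consume them.
    hidden_states = hidden_states.to(block.device)
    if mask is not None:
        mask = mask.to(block.device)
    return hidden_states, mask

class ToyBlock(nn.Module):
    def __init__(self, dev):
        super().__init__()
        self.device = torch.device(dev)
        self.proj = nn.Linear(64, 64).to(dev)

    def forward(self, x):
        return self.proj(x)

# Use two GPUs when available, otherwise keep everything on CPU.
devices = ["cuda:0", "cuda:1"] if torch.cuda.device_count() >= 2 else ["cpu", "cpu"]
blocks = [ToyBlock(d) for d in devices]

h, mask = torch.randn(1, 8, 64), None
for block in blocks:
    h, mask = move_to_block_device(block, h, mask)
    h = block(h)
print(h.device)  # ends on the last block's device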