Unverified Commit 09c73fb2 authored by Casper, committed by GitHub

Fix multi-GPU loading and inference (#190)

parent 299c460b
@@ -45,13 +45,13 @@ class AutoAWQForCausalLM:
     def from_quantized(self, quant_path, quant_filename='', max_new_tokens=None,
                        trust_remote_code=True, fuse_layers=True,
                        batch_size=1, safetensors=True,
-                       max_memory=None, offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
+                       device_map="balanced", offload_folder=None, **config_kwargs) -> BaseAWQForCausalLM:
         os.environ["AWQ_BATCH_SIZE"] = str(batch_size)
         model_type = check_and_get_model_type(quant_path, trust_remote_code)

         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
             quant_path, model_type, quant_filename, max_new_tokens, trust_remote_code=trust_remote_code,
             fuse_layers=fuse_layers, safetensors=safetensors,
-            max_memory=max_memory, offload_folder=offload_folder,
+            device_map=device_map, offload_folder=offload_folder,
             **config_kwargs
         )
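With this change the loader takes a `device_map` (default "balanced") instead of a raw `max_memory` dict. A minimal usage sketch, assuming the package's usual top-level import; the checkpoint path is a placeholder, not part of this commit:

# Sketch: load an AWQ-quantized checkpoint across all visible GPUs.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "path/to/awq-quantized-model",   # placeholder path
    fuse_layers=True,
    safetensors=True,
    device_map="balanced",           # new argument: spread layers evenly over GPUs
)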
@@ -14,8 +14,12 @@ from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.utils.module import get_named_linears, set_op_by_name
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
-from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
+from accelerate.big_modeling import (
+    init_empty_weights,
+    infer_auto_device_map,
+    load_checkpoint_and_dispatch,
+)
+from accelerate.utils import get_balanced_memory

 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, quant_config):
@@ -109,10 +113,18 @@ class BaseAWQForCausalLM(nn.Module):
         with init_empty_weights():
             model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)

+        # Evenly distribute memory on GPUs
+        max_memory = get_balanced_memory(
+            model,
+            no_split_module_classes=[self.layer_type],
+            dtype=torch_dtype
+        )
+
         # Get device map
         device_map = infer_auto_device_map(
             model,
+            max_memory=max_memory,
             no_split_module_classes=[self.layer_type],
             dtype=torch_dtype
         )
         del model
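For reference, the two accelerate helpers used here compose as follows. A standalone sketch under assumed placeholders (the "gpt2" config and "GPT2Block" class name are illustrative, not from this commit), requiring at least one CUDA device:

# Sketch: compute a balanced multi-GPU device map for a weightless (meta) model.
import torch
from accelerate import init_empty_weights
from accelerate.big_modeling import infer_auto_device_map
from accelerate.utils import get_balanced_memory
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("gpt2")              # placeholder architecture
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config)

# Split available GPU memory evenly so layers are not packed onto GPU 0 first.
max_memory = get_balanced_memory(
    model,
    no_split_module_classes=["GPT2Block"],               # keep each block on one device
    dtype=torch.float16,
)
device_map = infer_auto_device_map(
    model,
    max_memory=max_memory,
    no_split_module_classes=["GPT2Block"],
    dtype=torch.float16,
)
print(device_map)   # e.g. {'transformer.wte': 0, 'transformer.h.0': 0, ...}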
@@ -123,6 +135,7 @@ class BaseAWQForCausalLM(nn.Module):
             trust_remote_code=trust_remote_code,
             torch_dtype=torch_dtype,
             use_safetensors=safetensors,
+            device_map=device_map,
             **model_init_kwargs
         )
@@ -135,7 +148,7 @@ class BaseAWQForCausalLM(nn.Module):
                        max_new_tokens=None, torch_dtype=torch.float16,
                        trust_remote_code=True, safetensors=True, is_quantized=True,
                        fuse_layers=False, version='GEMM',
-                       max_memory=None, offload_folder=None,
+                       device_map="balanced", offload_folder=None,
                        **config_kwargs):
         # [STEP 1-2] Load weights path and configs
         model_weights_path, config, quant_config = self._load_config(
@@ -153,36 +166,21 @@ class BaseAWQForCausalLM(nn.Module):
         model.tie_weights()

-        # Get device map
-        device_map = infer_auto_device_map(
-            model,
-            no_split_module_classes=[self.layer_type],
-            max_memory=max_memory,
-            dtype=torch_dtype
-        )
-
-        # Load checkpoint
-        load_checkpoint_in_model(
+        # loads the weights into modules and distributes
+        # across available devices automatically
+        load_checkpoint_and_dispatch(
             model,
             checkpoint=model_weights_path,
             device_map=device_map,
+            no_split_module_classes=[self.layer_type],
             offload_folder=offload_folder,
-            dtype=torch_dtype
+            dtype=torch_dtype,
         )

         # Dispath to devices
         if fuse_layers:
             self.fuse_layers(model)

-        # Offloading dispatch
-        from accelerate import dispatch_model
-        model = dispatch_model(
-            model,
-            device_map=device_map,
-            offload_dir=offload_folder
-        )
-
         return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

     def _load_config(self, model_path, model_filename, safetensors=True,
...
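The separate infer_auto_device_map / load_checkpoint_in_model / dispatch_model steps collapse into a single load_checkpoint_and_dispatch call. A rough standalone sketch of the same pattern; the checkpoint path and the "LlamaDecoderLayer" class name are assumptions for illustration:

# Sketch: materialize a meta model directly onto the target devices in one call.
import torch
from accelerate import init_empty_weights
from accelerate.big_modeling import load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

checkpoint_dir = "path/to/quantized/checkpoint"          # placeholder
config = AutoConfig.from_pretrained(checkpoint_dir)

with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)
model.tie_weights()

# Loads the shards and places every module according to device_map in one step,
# spilling to offload_folder if the GPUs run out of memory.
load_checkpoint_and_dispatch(
    model,
    checkpoint=checkpoint_dir,
    device_map="balanced",
    no_split_module_classes=["LlamaDecoderLayer"],       # assumed block class name
    offload_folder="offload",
    dtype=torch.float16,
)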
@@ -19,6 +19,7 @@ class LlamaLikeBlock(nn.Module):
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
+        self.device = dev

     def forward(
         self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
@@ -30,7 +31,7 @@ class LlamaLikeBlock(nn.Module):
             attention_mask=attention_mask
         )

-        h = hidden_states + attn_output
+        h = hidden_states.to(attn_output.device) + attn_output
         out = h + self.mlp.forward(self.norm_2(h))

         return out, None, past_key_value
@@ -48,6 +49,7 @@ class MPTBlock(nn.Module):
         ).to(dev)
         self.norm_2 = norm_2
         self.ffn = mpt_mlp.to(dev)
+        self.device = dev

     def forward(
         self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
@@ -62,7 +64,7 @@ class MPTBlock(nn.Module):
             use_cache=True
         )

-        h = hidden_states + attn_output
+        h = hidden_states.to(attn_output.device) + attn_output
         out = h + self.ffn.forward(self.norm_2(h))

         return out, None, past_key_value
@@ -94,6 +96,7 @@ class FalconDecoderLayer(nn.Module):
         self.input_layernorm = input_layernorm # before attention
         self.mlp = mlp
+        self.device = dev

     def _get_attention_shapes(self, n_heads, max_seq_len, head_dim):
         batch_size = int(os.getenv("AWQ_BATCH_SIZE", "1"))
@@ -136,7 +139,7 @@ class FalconDecoderLayer(nn.Module):
             use_cache=True
         )

-        h_attn = hidden_states + attn_output
+        h_attn = hidden_states.to(attn_output.device) + attn_output

         if self.new_decoder_arch:
             h_mlp = self.mlp.forward(mlp_layernorm_out)
...
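The new self.device attribute and the .to(attn_output.device) casts exist because, under a multi-GPU device map, consecutive blocks can live on different GPUs, so the residual input has to follow the attention output onto the block's own device. A tiny illustration with toy tensors (needs two CUDA devices):

import torch

hidden_states = torch.randn(1, 8, 64, device="cuda:0")   # residual stream from the previous block
attn_output = torch.randn(1, 8, 64, device="cuda:1")     # produced by a block placed on cuda:1

# Without the .to(...), this add would raise a cross-device RuntimeError.
h = hidden_states.to(attn_output.device) + attn_output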
 import torch
 import torch.nn as nn
 from typing import List
+from awq.utils import fused_utils
 from transformers.modeling_outputs import BaseModelOutputWithPast
 from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
-from awq.utils.fused_utils import prepare_attention_mask, prepare_input_ids, prepare_cache

 class LlamaLikeModel(nn.Module):
     """
@@ -20,17 +20,17 @@ class LlamaLikeModel(nn.Module):
     @torch.inference_mode()
     def forward(self, input_ids: torch.Tensor, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

         h = self.embedding(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
             seqlen=seqlen,
             start_pos=self.blocks[0].attn.start_pos,
             device=input_ids.device,
@@ -38,7 +38,17 @@ class LlamaLikeModel(nn.Module):
         )

         for layer in self.blocks:
-            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(
+                h,
+                None,
+                attention_mask=mask,
+                is_causal=is_causal
+            )
         h = self.norm(h)

         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
@@ -56,17 +66,17 @@ class MPTModel(nn.Module):
     @torch.inference_mode()
     def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

         h = self.wte(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
             seqlen=seqlen,
             start_pos=self.blocks[0].attn.start_pos,
             device=input_ids.device,
@@ -74,7 +84,17 @@ class MPTModel(nn.Module):
         )

         for layer in self.blocks:
-            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(
+                h,
+                None,
+                attention_mask=mask,
+                is_causal=is_causal
+            )
         h = self.norm_f(h)

         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
@@ -92,17 +112,17 @@ class FalconModel(nn.Module):
     @torch.inference_mode()
     def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
-        input_ids, self.last_forward_num_tokens = prepare_input_ids(
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
             input_ids,
             self.last_forward_num_tokens
         )
         _bsz, seqlen = input_ids.shape

-        prepare_cache(self.blocks, seqlen)
+        fused_utils.prepare_cache(self.blocks, seqlen)

         h = self.word_embeddings(input_ids)

-        mask = prepare_attention_mask(
+        mask = fused_utils.prepare_attention_mask(
             seqlen=seqlen,
             start_pos=self.blocks[0].attn.start_pos,
             device=input_ids.device,
@@ -110,7 +130,17 @@ class FalconModel(nn.Module):
         )

         for layer in self.blocks:
-            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(
+                h,
+                None,
+                attention_mask=mask,
+                is_causal=is_causal
+            )
         h = self.ln_f(h)

         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
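Each fused model's forward loop now moves the running hidden state (and mask) onto the next block's device before calling it, which is what makes pipelined multi-GPU inference work. Conceptually, with hypothetical toy blocks spread over two GPUs:

# Sketch: hop activations from device to device as the loop crosses GPU boundaries.
# Toy linear "blocks"; the real fused blocks store self.device in this commit.
import torch
import torch.nn as nn

blocks = [nn.Linear(64, 64).to("cuda:0"), nn.Linear(64, 64).to("cuda:1")]
for block in blocks:
    block.device = next(block.parameters()).device   # mimic the new Block.device attribute

h = torch.randn(1, 64, device="cuda:0")
for layer in blocks:
    h = h.to(layer.device)   # what fused_utils.prepare_correct_devices does for h (and the mask)
    h = layer(h)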
@@ -101,7 +101,7 @@ class WQLinear_GEMM(nn.Module):
         input_dtype = x.dtype
         if input_dtype != torch.float16:
             x = x.half()

         out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)

         if input_dtype != torch.float16:
...
 import torch
+from typing import List
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

+def prepare_correct_devices(next_layer, hidden_states, mask):
+    hidden_states = hidden_states.to(next_layer.device)
+
+    if mask is not None:
+        mask = mask.to(next_layer.device)
+
+    return hidden_states, mask
+
 def prepare_cache(blocks, seqlen: int) -> int:
     for block in blocks:
         start_pos = block.attn.start_pos
...
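A quick usage note for the new helper above; the nn.Identity stand-in and the manually set device attribute are only for illustration (the fused blocks set self.device = dev in this commit):

import torch
import torch.nn as nn
from awq.utils.fused_utils import prepare_correct_devices

layer = nn.Identity()
layer.device = torch.device("cpu")    # placeholder; real blocks carry their dispatch device

h = torch.randn(1, 8, 64)
mask = None                           # the helper tolerates a missing attention mask
h, mask = prepare_correct_devices(layer, h, mask)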