Unverified Commit 0fbfc4b8 authored by CHU Tianxiang, committed by GitHub

Add GPTQ support (#916)

parent c06170cc
@@ -377,6 +377,9 @@ class ChatGLMForCausalLM(nn.Module):
                 continue
             if "word_embeddings" in name:
                 name = name.replace(".word_embeddings", "")
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
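Note: this `.bias` guard is the heart of the change and recurs in every loader touched below. GPTQ checkpoints commonly serialize a bias tensor for every quantized linear layer, even when the corresponding vLLM module is built with bias=False, so the checkpoint can contain `.bias` entries that have no matching parameter. A minimal sketch of the pattern, using an illustrative iterator and a stand-in default loader rather than the real vLLM helpers:

from typing import Dict, Iterable, Tuple

import torch


def load_checkpoint(params_dict: Dict[str, torch.nn.Parameter],
                    weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
    for name, loaded_weight in weights:
        # GPTQ checkpoints may ship a ".bias" tensor for layers the model
        # defines without bias; skip those instead of raising a KeyError.
        if name.endswith(".bias") and name not in params_dict:
            continue
        param = params_dict[name]
        # Fall back to a plain copy when the parameter has no custom loader.
        weight_loader = getattr(param, "weight_loader",
                                lambda p, w: p.data.copy_(w))
        weight_loader(param, loaded_weight)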
@@ -425,27 +425,32 @@ class FalconForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             if "query_key_value" in name:
                 output_dim = getattr(param, "output_dim", None)
                 loaded_weight_shape = loaded_weight.shape
-                loaded_weight = loaded_weight.view(
-                    loaded_weight_shape[:output_dim] +
-                    (total_num_kv_heads, num_query_heads_per_kv_head + 2, -1) +
-                    loaded_weight_shape[output_dim + 1:])
-                wq = loaded_weight.narrow(
-                    output_dim + 1, 0, num_query_heads_per_kv_head).reshape(
-                        *loaded_weight_shape[:output_dim], -1,
-                        *loaded_weight_shape[output_dim + 1:])
-                wk = loaded_weight.narrow(
-                    output_dim + 1, num_query_heads_per_kv_head,
-                    1).reshape(*loaded_weight_shape[:output_dim], -1,
-                               *loaded_weight_shape[output_dim + 1:])
-                wv = loaded_weight.narrow(
-                    output_dim + 1, num_query_heads_per_kv_head + 1,
-                    1).reshape(*loaded_weight_shape[:output_dim], -1,
-                               *loaded_weight_shape[output_dim + 1:])
-                loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
+                if output_dim is not None:
+                    loaded_weight = loaded_weight.view(
+                        loaded_weight_shape[:output_dim] +
+                        (total_num_kv_heads, num_query_heads_per_kv_head + 2,
+                         -1) + loaded_weight_shape[output_dim + 1:])
+                    wq = loaded_weight.narrow(
+                        output_dim + 1, 0,
+                        num_query_heads_per_kv_head).reshape(
+                            *loaded_weight_shape[:output_dim], -1,
+                            *loaded_weight_shape[output_dim + 1:])
+                    wk = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    wv = loaded_weight.narrow(
+                        output_dim + 1, num_query_heads_per_kv_head + 1,
+                        1).reshape(*loaded_weight_shape[:output_dim], -1,
+                                   *loaded_weight_shape[output_dim + 1:])
+                    loaded_weight = torch.cat([wq, wk, wv], dim=output_dim)
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
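The new `output_dim is not None` guard restricts Falcon's fused-QKV re-interleaving to parameters that actually expose an output dimension; tensors without one (for example per-input-channel quantization metadata in a GPTQ checkpoint) are presumably meant to pass through untouched and be handled by their own weight_loader. The narrow/reshape/cat dance itself just converts Falcon's per-KV-head [q_0..q_{n-1}, k, v] packing into the [all Q | all K | all V] layout the fused projection expects. A roughly equivalent, runnable sketch on a plain tensor (all shapes and names below are illustrative, not the vLLM code):

import torch

total_num_kv_heads = 2
num_query_heads_per_kv_head = 4
head_size = 8
in_features = 16
output_dim = 0  # the fused weight's output (row) dimension

# Falcon packs [q_0 .. q_{n-1}, k, v] per KV-head group along the output dim.
fused = torch.randn(
    total_num_kv_heads * (num_query_heads_per_kv_head + 2) * head_size,
    in_features)

grouped = fused.view(total_num_kv_heads, num_query_heads_per_kv_head + 2,
                     head_size, in_features)
wq = grouped[:, :num_query_heads_per_kv_head].reshape(-1, in_features)
wk = grouped[:, num_query_heads_per_kv_head:
             num_query_heads_per_kv_head + 1].reshape(-1, in_features)
wv = grouped[:, num_query_heads_per_kv_head + 1:].reshape(-1, in_features)
# Regrouped layout is [all Q | all K | all V] along the output dimension.
regrouped = torch.cat([wq, wk, wv], dim=output_dim)
assert regrouped.shape == fused.shape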
@@ -275,7 +275,6 @@ class GPT2LMHeadModel(nn.Module):
             if not name.endswith(".weight"):
                 continue
-            loaded_weight = loaded_weight.t()
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
@@ -274,11 +274,18 @@ class GPTJForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
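The same rewrite, renaming the checkpoint tensor to its fused parameter name first and only then applying the GPTQ bias check, is applied verbatim to InternLM, LLaMA, Mistral, Mixtral, OPT, Qwen, and Yi below. A toy version of the shared for/else lookup, with a made-up mapping and parameter dict rather than the real vLLM names:

from typing import Optional, Tuple

import torch

stacked_params_mapping = [
    # (fused param name, checkpoint shard name, shard id) -- illustrative only
    ("qkv_proj", "q_proj", "q"),
    ("qkv_proj", "k_proj", "k"),
    ("qkv_proj", "v_proj", "v"),
]


def resolve(name: str, params_dict: dict
            ) -> Optional[Tuple[torch.nn.Parameter, Optional[str]]]:
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        # Rename *before* the GPTQ bias check so fused parameters are
        # filtered the same way as plain ones.
        name = name.replace(weight_name, param_name)
        if name.endswith(".bias") and name not in params_dict:
            return None  # extra GPTQ bias: skip it
        return params_dict[name], shard_id
    # Not a stacked parameter: same bias check, plain lookup.
    if name.endswith(".bias") and name not in params_dict:
        return None
    return params_dict[name], None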
@@ -72,7 +72,6 @@ class GPTNeoXAttention(nn.Module):
             config.hidden_size,
             linear_method=linear_method,
         )
-
        scaling = self.head_size**-0.5
        rotary_dim = int(self.head_size * config.rotary_pct)
        assert rotary_dim % 2 == 0
@@ -289,11 +289,18 @@ class InternLMForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -330,11 +330,18 @@ class LlamaForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -321,11 +321,18 @@ class MistralForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
...@@ -153,7 +153,7 @@ class MixtralMoE(nn.Module): ...@@ -153,7 +153,7 @@ class MixtralMoE(nn.Module):
self.gate = ReplicatedLinear(config.hidden_size, self.gate = ReplicatedLinear(config.hidden_size,
self.num_total_experts, self.num_total_experts,
bias=False, bias=False,
linear_method=linear_method) linear_method=None)
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape batch_size, sequence_length, hidden_dim = hidden_states.shape
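The hunk above keeps the MoE router in full precision by passing `linear_method=None` to `self.gate` even when the rest of the model uses a quantized linear method, presumably because the routing projection is tiny and its logits are sensitive to quantization error. A hedged sketch of that pattern with stand-in classes (not the vLLM API):

import torch
import torch.nn as nn


def make_linear(in_features: int, out_features: int, linear_method=None) -> nn.Module:
    # Stand-in for vLLM's parallel linear layers: a non-None ``linear_method``
    # would build a quantized layer; ``None`` means plain full precision.
    if linear_method is None:
        return nn.Linear(in_features, out_features, bias=False)
    return linear_method(in_features, out_features)


class ToyMoE(nn.Module):

    def __init__(self, hidden_size: int, num_experts: int, linear_method=None):
        super().__init__()
        # The router is built with linear_method=None on purpose, mirroring
        # the diff: it stays unquantized regardless of the model-wide setting.
        self.gate = make_linear(hidden_size, num_experts, linear_method=None)
        self.experts = nn.ModuleList(
            make_linear(hidden_size, hidden_size, linear_method)
            for _ in range(num_experts))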
@@ -418,11 +418,18 @@ class MixtralForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -297,6 +297,9 @@ class MPTForCausalLM(nn.Module):
         params_dict = dict(self.named_parameters(remove_duplicate=False))
         for name, loaded_weight in hf_model_weights_iterator(
                 model_name_or_path, cache_dir, load_format, revision):
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
@@ -345,11 +345,18 @@ class OPTForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -305,6 +305,9 @@ class PhiForCausalLM(nn.Module):
             if "rotary_emb.inv_freq" in name:
                 continue
+            # Skip loading extra bias for GPTQ models.
+            if name.endswith(".bias") and name not in params_dict:
+                continue
             # pylint: disable=E1136
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
...@@ -82,7 +82,6 @@ class QWenAttention(nn.Module): ...@@ -82,7 +82,6 @@ class QWenAttention(nn.Module):
self.num_heads = (self.total_num_heads // self.num_heads = (self.total_num_heads //
tensor_model_parallel_world_size) tensor_model_parallel_world_size)
self.head_dim = hidden_size // self.total_num_heads self.head_dim = hidden_size // self.total_num_heads
self.c_attn = QKVParallelLinear( self.c_attn = QKVParallelLinear(
hidden_size, hidden_size,
self.head_dim, self.head_dim,
...@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module): ...@@ -279,11 +278,18 @@ class QWenLMHeadModel(nn.Module):
for (param_name, weight_name, shard_id) in stacked_params_mapping: for (param_name, weight_name, shard_id) in stacked_params_mapping:
if weight_name not in name: if weight_name not in name:
continue continue
param = params_dict[name.replace(weight_name, param_name)] name = name.replace(weight_name, param_name)
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id) weight_loader(param, loaded_weight, shard_id)
break break
else: else:
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
param = params_dict[name] param = params_dict[name]
weight_loader = getattr(param, "weight_loader", weight_loader = getattr(param, "weight_loader",
default_weight_loader) default_weight_loader)
......
@@ -320,11 +320,18 @@ class YiForCausalLM(nn.Module):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
@@ -287,4 +287,5 @@ def initialize_dummy_weights(
         values between -1e-3 and 1e-3 works well for most models.
     """
     for param in model.state_dict().values():
-        param.data.uniform_(low, high)
+        if torch.is_floating_point(param):
+            param.data.uniform_(low, high)
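For reference, the guarded dummy initializer now behaves roughly like the standalone function below. The only change is the floating-point check: quantized models register integer tensors (for example packed GPTQ weights and zero points) that `uniform_` cannot fill, so they are left untouched.

import torch
import torch.nn as nn


def initialize_dummy_weights(model: nn.Module,
                             low: float = -1e-3,
                             high: float = 1e-3) -> None:
    """Fill floating-point parameters with small random values; skip
    integer tensors such as packed GPTQ weights."""
    for param in model.state_dict().values():
        if torch.is_floating_point(param):
            param.data.uniform_(low, high)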