Unverified Commit bcf78947 authored by Casper's avatar Casper Committed by GitHub
Browse files

Merge pull request #23 from casper-hansen/yarn

YaRN support for LLaMa models
parents 198ba2fb 47ab20a9
...@@ -321,7 +321,13 @@ class BaseAWQForCausalLM(nn.Module): ...@@ -321,7 +321,13 @@ class BaseAWQForCausalLM(nn.Module):
# Load model weights # Load model weights
model = AutoModelForCausalLM.from_pretrained( model = AutoModelForCausalLM.from_pretrained(
model_filename, device_map=device_map, offload_folder="offload", offload_state_dict=True, torch_dtype=torch_dtype, use_safetensors=safetensors model_filename,
device_map=device_map,
trust_remote_code=trust_remote_code,
offload_folder="offload",
offload_state_dict=True,
torch_dtype=torch_dtype,
use_safetensors=safetensors
) )
model.eval() model.eval()
......
...@@ -193,12 +193,16 @@ def apply_scale(module, scales_list, input_feat_dict=None): ...@@ -193,12 +193,16 @@ def apply_scale(module, scales_list, input_feat_dict=None):
if isinstance(prev_op, nn.Linear): if isinstance(prev_op, nn.Linear):
assert len(layers) == 1 assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales) scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
elif any(isinstance(prev_op,t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
or 'rmsnorm' in str(prev_op.__class__).lower():
scale_ln_fcs(prev_op, layers, scales) scale_ln_fcs(prev_op, layers, scales)
elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]): elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
new_module = ScaledActivation(prev_op, scales) new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module) set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales) scale_gelu_fc(prev_op, layers[0], scales)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"prev_op {type(prev_op)} not supported yet!") f"prev_op {type(prev_op)} not supported yet!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment