"src/vscode:/vscode.git/clone" did not exist on "1e5eaca754bce676ce9142cab7ccaaee78df4696"
Commit 47ab20a9 authored by Casper Hansen's avatar Casper Hansen
Browse files

Add support for rmsnorm through class name

parent 0090ad81
...@@ -193,12 +193,16 @@ def apply_scale(module, scales_list, input_feat_dict=None): ...@@ -193,12 +193,16 @@ def apply_scale(module, scales_list, input_feat_dict=None):
if isinstance(prev_op, nn.Linear): if isinstance(prev_op, nn.Linear):
assert len(layers) == 1 assert len(layers) == 1
scale_fc_fc(prev_op, layers[0], scales) scale_fc_fc(prev_op, layers[0], scales)
elif isinstance(prev_op, (nn.LayerNorm, LlamaRMSNorm)):
elif any(isinstance(prev_op,t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
or 'rmsnorm' in str(prev_op.__class__).lower():
scale_ln_fcs(prev_op, layers, scales) scale_ln_fcs(prev_op, layers, scales)
elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]): elif any(isinstance(prev_op,t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
new_module = ScaledActivation(prev_op, scales) new_module = ScaledActivation(prev_op, scales)
set_op_by_name(module, prev_op_name, new_module) set_op_by_name(module, prev_op_name, new_module)
scale_gelu_fc(prev_op, layers[0], scales) scale_gelu_fc(prev_op, layers[0], scales)
else: else:
raise NotImplementedError( raise NotImplementedError(
f"prev_op {type(prev_op)} not supported yet!") f"prev_op {type(prev_op)} not supported yet!")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment