Commit 44470a7a authored by Casper's avatar Casper
Browse files

Added type hinting

parent 45c22ee5
......@@ -63,7 +63,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
scales.cpu()
@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
def scale_ln_fcs(ln: nn.Linear, fcs: list[nn.Linear], scales: torch.Tensor):
if not isinstance(fcs, list):
fcs = [fcs]
......@@ -83,7 +83,7 @@ def scale_ln_fcs(ln, fcs, scales):
assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
assert isinstance(fc1, nn.Linear)
assert isinstance(fc2, nn.Linear)
......@@ -102,7 +102,7 @@ def scale_fc_fc(fc1, fc2, scales):
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
assert any(isinstance(gelu,t) for t in allowed_act_fns)
assert isinstance(fc, nn.Linear)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment