"src/git@developer.sourcefind.cn:chenpangpang/open-webui.git" did not exist on "e82027310dbf7a6ee6ad54c332c5bc2f012fe0b4"
Commit 52c53f39 authored by thomwolf
Browse files

clean up apex integration

parent 4946c2c5
...@@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor ...@@ -59,9 +59,9 @@ def convert_tf_checkpoint_to_pytorch(tf_checkpoint_path, bert_config_file, pytor
l = re.split(r'_(\d+)', m_name) l = re.split(r'_(\d+)', m_name)
else: else:
l = [m_name] l = [m_name]
if l[0] == 'kernel': if l[0] == 'kernel' or l[0] == 'gamma':
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, 'weight')
elif l[0] == 'output_bias': elif l[0] == 'output_bias' or l[0] == 'beta':
pointer = getattr(pointer, 'bias') pointer = getattr(pointer, 'bias')
elif l[0] == 'output_weights': elif l[0] == 'output_weights':
pointer = getattr(pointer, 'weight') pointer = getattr(pointer, 'weight')
......
...@@ -516,9 +516,9 @@ class PreTrainedBertModel(nn.Module): ...@@ -516,9 +516,9 @@ class PreTrainedBertModel(nn.Module):
for key in state_dict.keys(): for key in state_dict.keys():
new_key = None new_key = None
if 'gamma' in key: if 'gamma' in key:
new_key = key.replace('gamma','weight') new_key = key.replace('gamma', 'weight')
if 'beta' in key: if 'beta' in key:
new_key = key.replace('beta','bias') new_key = key.replace('beta', 'bias')
if new_key: if new_key:
old_keys.append(key) old_keys.append(key)
new_keys.append(new_key) new_keys.append(new_key)
......
...@@ -35,7 +35,7 @@ class OptimizationTest(unittest.TestCase): ...@@ -35,7 +35,7 @@ class OptimizationTest(unittest.TestCase):
criterion = torch.nn.MSELoss() criterion = torch.nn.MSELoss()
# No warmup, constant schedule, no gradient clipping # No warmup, constant schedule, no gradient clipping
optimizer = BertAdam(params=[w], lr=2e-1, optimizer = BertAdam(params=[w], lr=2e-1,
weight_decay=0.0, weight_decay_rate=0.0,
max_grad_norm=-1) max_grad_norm=-1)
for _ in range(100): for _ in range(100):
loss = criterion(w, target) loss = criterion(w, target)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment