"src/vscode:/vscode.git/clone" did not exist on "125d783076e5bd9785beb05367a2d2566843a271"
Commit f7f2dd01 authored by Maggie Li's avatar Maggie Li Committed by Myle Ott
Browse files

Add ensemble for different architectures (#235)

parent 202e0bbe
......@@ -146,17 +146,20 @@ def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
state = torch.load(filename, map_location=lambda s, l: default_restore_location(s, 'cpu'))
state = _upgrade_state_dict(state)
states.append(state)
args = states[0]['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
# build ensemble
ensemble = []
for state in states:
args = state['args']
if model_arg_overrides is not None:
args = _override_model_args(args, model_arg_overrides)
# build model for ensemble
model = task.build_model(args)
model.upgrade_state_dict(state['model'])
model.load_state_dict(state['model'], strict=True)
ensemble.append(model)
return ensemble, args
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment