Commit 9caf30ae authored by Jennifer Wei's avatar Jennifer Wei
Browse files

Change casting for deepspeed compare model test to fp32

parent 0c2d455e
......@@ -315,8 +315,7 @@ class TestDeepSpeedKernel(unittest.TestCase):
# Move the recycling dimension to the end
move_dim = lambda t: t.permute(*range(len(t.shape))[1:], 0)
batch = tensor_tree_map(move_dim, batch)
with torch.no_grad():
with torch.cuda.amp.autocast(dtype=torch.bfloat16):
with torch.no_grad(), torch.cuda.amp.autocast(dtype=torch.float32):
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model(batch)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment