Commit 9a07b7f9 authored by Christina Floristean's avatar Christina Floristean
Browse files

Fix deepspeed test for multimer

parent 9e057b7a
......@@ -193,13 +193,15 @@ class Linear(nn.Linear):
)
if self.precision is not None:
with torch.cuda.amp.autocast(enabled=False):
bias = self.bias.to(dtype=self.precision) if self.bias is not None else None
return nn.functional.linear(input.to(dtype=self.precision),
self.weight.to(dtype=self.precision),
self.bias.to(dtype=self.precision)).to(dtype=d)
bias).to(dtype=d)
if d is torch.bfloat16 and not deepspeed_is_initialized:
with torch.cuda.amp.autocast(enabled=False):
return nn.functional.linear(input, self.weight.to(dtype=d), self.bias.to(dtype=d))
bias = self.bias.to(dtype=d) if self.bias is not None else None
return nn.functional.linear(input, self.weight.to(dtype=d), bias)
return nn.functional.linear(input, self.weight, self.bias)
......
......@@ -236,19 +236,28 @@ class TestDeepSpeedKernel(unittest.TestCase):
n_res = 20
eps = 2e-2
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
batch = random_template_feats(n_templ, n_res)
batch["template_all_atom_masks"] = batch["template_all_atom_mask"]
if consts.is_multimer:
batch["asym_id"] = batch['asym_id'][0]
pair_act = np.random.rand(n_res, n_res, consts.c_z).astype(np.float32)
pair_mask = np.random.randint(0, 2, (n_res, n_res)).astype(np.float32)
inds = np.random.randint(0, 21, (n_res,))
batch["target_feat"] = np.eye(22)[inds]
batch = {k: torch.as_tensor(v).cuda() for k, v in batch.items()}
template_feats = {
k: v for k, v in batch.items() if k.startswith("template_")
}
with torch.no_grad():
model = compare_utils.get_global_pretrained_openfold()
model.globals.use_deepspeed_evo_attention = False
out_repro = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
template_feats,
batch,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
......@@ -258,7 +267,8 @@ class TestDeepSpeedKernel(unittest.TestCase):
model.globals.use_deepspeed_evo_attention = True
out_repro_ds = model.embed_templates(
{k: torch.as_tensor(v).cuda() for k, v in batch.items()},
template_feats,
batch,
torch.as_tensor(pair_act).cuda(),
torch.as_tensor(pair_mask).cuda(),
templ_dim=0,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment