rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.weight.grad,'(three o) i -> three o i',three=3)[:,rank*partition_dim:(rank+1)*partition_dim],'three o i -> (three o) i'),
rtol=rtol,atol=atol*10
)
asserttorch.allclose(
model.transformer.layers[i].mixer.Wqkv.bias.grad,
rearrange(rearrange(model_pt.transformer.layers[i].mixer.Wqkv.bias.grad,'(three o) -> three o',three=3)[:,rank*partition_dim:(rank+1)*partition_dim],