Unverified Commit ea8bfcb0 authored by oahzxl, committed by GitHub

add unit test for layernorm, linear, outproductmean (#82)

* update out product mean test

* polish out product mean test and update linear test

* update test for layernorm, use both kernel

* support test without triton

* polish layernorm

* use current code
parent f2d8aa06
import torch

from fastfold.model.fastnn.ops import Linear as FastLinear
from fastfold.model.nn.primitives import Linear


def test_linear():
    c_in = 3
    c_out = 4
    seq = 5

    fast_linear = FastLinear(c_in, c_out).cuda()
    linear = Linear(c_in, c_out).cuda()

    # share parameters so both modules compute the same affine map
    fast_linear.weight = linear.weight
    fast_linear.bias = linear.bias

    x = torch.randn((seq, c_in)).cuda()
    out1 = fast_linear(x)
    out2 = linear(x)
    assert torch.allclose(out1, out2, atol=1e-8)


if __name__ == "__main__":
    test_linear()
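The standalone check above exercises a single input shape. As a hypothetical extension (not part of this commit; the test name and extra shapes are illustrative, and it assumes both Linear variants accept arbitrary leading dimensions), the same weight-sharing comparison could be parametrized with pytest:

import pytest
import torch

from fastfold.model.fastnn.ops import Linear as FastLinear
from fastfold.model.nn.primitives import Linear


@pytest.mark.parametrize("shape", [(5, 3), (2, 5, 3)])  # illustrative shapes; last dim is c_in
def test_linear_shapes(shape):
    c_in, c_out = shape[-1], 4
    fast_linear = FastLinear(c_in, c_out).cuda()
    linear = Linear(c_in, c_out).cuda()
    # share parameters so both modules compute the same affine map
    fast_linear.weight = linear.weight
    fast_linear.bias = linear.bias
    x = torch.randn(shape).cuda()
    assert torch.allclose(fast_linear(x), linear(x), atol=1e-8)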
import torch

from fastfold.model.fastnn.kernel import LayerNorm as FastLayerNorm
from fastfold.model.fastnn.kernel.layer_norm import FusedLayerNormAffineFunction

# the triton kernel is optional; skip that path if it cannot be imported
triton = True
try:
    from fastfold.model.fastnn.kernel.layer_norm import LayerNormTritonFunc
except ImportError:
    print("Skip triton layernorm test!")
    triton = False

def test_layernorm():
    # ... lines collapsed in the diff view (test_device, dtype, shape, sample_input, and tolerance_eps are defined above) ...
    dim_ = sample_input.size()[-1]
    torch_module = torch.nn.LayerNorm(normalized_shape=dim_).to(device=test_device, dtype=dtype)
    fastnn_cuda_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device, dtype=dtype)
    if triton:
        fastnn_triton_module = FastLayerNorm(normalized_shape=dim_).to(device=test_device, dtype=dtype)

    # Forward: compare both fused kernels against torch.nn.LayerNorm
    torch_out = torch_module(sample_input)
    fastnn_cuda_out = FusedLayerNormAffineFunction.apply(sample_input, fastnn_cuda_module.weight,
                                                         fastnn_cuda_module.bias,
                                                         fastnn_cuda_module.normalized_shape,
                                                         fastnn_cuda_module.eps)
    forward_error = torch.max(torch.abs(torch_out - fastnn_cuda_out)).cpu().item()
    assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
    if triton:
        fastnn_triton_out = LayerNormTritonFunc.apply(sample_input, fastnn_triton_module.normalized_shape,
                                                      fastnn_triton_module.weight,
                                                      fastnn_triton_module.bias,
                                                      fastnn_triton_module.eps)
        forward_error = torch.max(torch.abs(torch_out - fastnn_triton_out)).cpu().item()
        assert forward_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"

    # Backward: gradients of the affine parameters must also match
    out_grad = torch.rand_like(torch_out).requires_grad_(False)
    torch_out.backward(out_grad)
    fastnn_cuda_out.backward(out_grad)
    backward_weight_error = torch.max(
        torch.abs(torch_module.weight.grad - fastnn_cuda_module.weight.grad)).cpu().item()
    assert backward_weight_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
    backward_bias_error = torch.max(
        torch.abs(torch_module.bias.grad - fastnn_cuda_module.bias.grad)).cpu().item()
    assert backward_bias_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
    if triton:
        fastnn_triton_out.backward(out_grad)
        backward_weight_error = torch.max(
            torch.abs(torch_module.weight.grad - fastnn_triton_module.weight.grad)).cpu().item()
        assert backward_weight_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"
        backward_bias_error = torch.max(
            torch.abs(torch_module.bias.grad - fastnn_triton_module.bias.grad)).cpu().item()
        assert backward_bias_error < tolerance_eps[dtype], f"Error when {shape} {dtype}"


if __name__ == "__main__":
    test_layernorm()
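For reference, both kernels are compared against torch.nn.LayerNorm. A small self-contained sketch (independent of fastfold; uses PyTorch's default eps) of the normalization that the forward comparison asserts:

import torch


def reference_layer_norm(x, weight, bias, eps=1e-5):
    # normalize over the last dimension, then apply the affine parameters
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias


x = torch.randn(4, 8)
ln = torch.nn.LayerNorm(8)
torch.testing.assert_close(reference_layer_norm(x, ln.weight, ln.bias, ln.eps), ln(x))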
import torch

import fastfold
from fastfold.model.fastnn.ops import OutProductMean as FastOutProductMean, set_chunk_size
from fastfold.model.nn.outer_product_mean import OuterProductMean


def test_out_product_mean():
    fastfold.distributed.init_dap()

    msa_len = 20
    seq_len = 30
    dim_m = 32
    dim_z = 64
    hidden = 16

    fast_opm = FastOutProductMean(n_feat=dim_m, n_feat_out=dim_z, n_feat_proj=hidden).cuda()
    opm = OuterProductMean(c_m=dim_m, c_z=dim_z, c_hidden=hidden).cuda()

    # share parameters so both modules compute the same projections
    fast_opm.linear_a.weight = opm.linear_1.weight
    fast_opm.linear_a.bias = opm.linear_1.bias
    fast_opm.linear_b.weight = opm.linear_2.weight
    fast_opm.linear_b.bias = opm.linear_2.bias
    fast_opm.o_linear.weight = opm.linear_out.weight
    fast_opm.o_linear.bias = opm.linear_out.bias

    m = torch.randn((1, msa_len, seq_len, dim_m)).cuda()
    m_mask = torch.ones((1, msa_len, seq_len)).cuda()
    m_mask[:, :, -5:] = 0  # mask out the last five residues
    z = torch.zeros((1, seq_len, seq_len, dim_z)).cuda()

    out_ref = opm(m, m_mask)
    out_fast = fast_opm(m, m_mask, z)
    assert torch.allclose(out_ref, out_fast, atol=1e-6)

    # chunked execution of the fastnn module must give the same result
    set_chunk_size(1)
    out_fast = fast_opm(m, m_mask, z)
    assert torch.allclose(out_ref, out_fast, atol=1e-6)

    # the in-place variant should match as well
    out_fast = fast_opm.inplace(m, m_mask, [z])[0]
    assert torch.allclose(out_ref, out_fast, atol=1e-6)


if __name__ == "__main__":
    test_out_product_mean()
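For orientation, the operation under test projects the MSA representation twice, forms a masked outer product over the MSA dimension, and normalizes by the mask counts before the output projection. Below is a simplified, hedged sketch of the masked-mean step only (not the fastfold implementation; the shapes and eps constant are illustrative):

import torch


def masked_outer_product_mean(a, b, mask, eps=1e-3):
    # a, b: [batch, msa, seq, hidden] projected MSA features; mask: [batch, msa, seq]
    a = a * mask[..., None]
    b = b * mask[..., None]
    # sum the pairwise outer products over the MSA dimension
    outer = torch.einsum("bmic,bmjd->bijcd", a, b)
    # count how many unmasked MSA rows contributed to each (i, j) pair
    norm = torch.einsum("bmi,bmj->bij", mask, mask)
    return outer / (norm[..., None, None] + eps)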