Commit c0daa62e authored by Tri Dao

Add type check (fp16) in the forward pass

parent ea38d3d2
@@ -130,6 +130,9 @@ mha_fwd(const at::Tensor &qkv, // total x num_heads x 3 x head_size, tot
bool is_dropout = p_dropout > 0.0;
Launch_params<Fused_multihead_attention_fprop_params> launch_params(dprops, stream, is_dropout, return_softmax);
TORCH_CHECK(qkv.dtype() == torch::kFloat16);
TORCH_CHECK(cu_seqlens.dtype() == torch::kInt32);
TORCH_CHECK(qkv.is_cuda());
TORCH_CHECK(cu_seqlens.is_cuda());