[Bug][Feature] Added cublasGemm<__half> specialization (#3988) (#4029)
* Added specialization of cublasGemm function for `__half` type, to try to address https://github.com/dmlc/dgl/issues/3988
* Added USE_FP16 guard
* Added test cases to test_segment_mm, to test newly-added FP16 specialization of cublasGemm
* Replaced for loop in test_segment_mm with pytest.mark.parametrize, as recommended

Co-authored-by: Xin Yao <xiny@nvidia.com>
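
For illustration only, here is a minimal sketch of the `pytest.mark.parametrize` pattern mentioned above, assuming a test that should cover both half and single precision. It is not the code from this commit: `segment_mm`, `cast`, `FORMATS`, the tolerances, and the data are all hypothetical stand-ins, and FP16 is emulated in pure Python by rounding through the `struct` format code `"e"` rather than by calling cuBLAS.

```python
import struct

import pytest

# Hypothetical stand-ins for tensor dtypes: struct format codes that round a
# Python float to IEEE half ("e") or single ("f") precision.
FORMATS = {"float16": "e", "float32": "f"}


def cast(x, dtype):
    """Round x to the precision named by dtype."""
    fmt = FORMATS[dtype]
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def segment_mm(a, b, seglen, dtype):
    """Reference segment matmul: the rows in segment i of `a` are multiplied by b[i].

    Inputs and outputs are rounded to `dtype`; accumulation happens in
    double precision (an assumption of this sketch).
    """
    out, start = [], 0
    for w, n in zip(b, seglen):
        for row in a[start:start + n]:
            out.append([
                cast(sum(cast(x, dtype) * cast(w[k][j], dtype)
                         for k, x in enumerate(row)), dtype)
                for j in range(len(w[0]))
            ])
        start += n
    return out


# One test body, run once per (dtype, tol) pair, instead of a for loop
# inside the test. A looser tolerance is used for the half-precision case.
@pytest.mark.parametrize("dtype,tol", [("float16", 1e-2), ("float32", 1e-5)])
def test_segment_mm(dtype, tol):
    a = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
    b = [[[1.0, 2.0], [3.0, 4.0]], [[0.5, 0.0], [0.0, 0.5]]]
    seglen = [2, 1]  # first two rows use b[0], last row uses b[1]
    expected = [[0.7, 1.0], [1.5, 2.2], [0.25, 0.3]]
    got = segment_mm(a, b, seglen, dtype)
    for got_row, exp_row in zip(got, expected):
        for g, e in zip(got_row, exp_row):
            assert abs(g - e) <= tol
```

With parametrization, each precision shows up as its own test case in the pytest report, so a failure in the half-precision path does not hide the result of the single-precision one.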