-
ndickson-nvidia authored
* Added specialization of cublasGemm function for `__half` type, to try to address https://github.com/dmlc/dgl/issues/3988
* Added USE_FP16 guard
* Added test cases to test_segment_mm, to test newly-added FP16 specialization of cublasGemm
* Replaced for loop in test_segment_mm with pytest.mark.parametrize, as recommended

Co-authored-by: Xin Yao <xiny@nvidia.com>
eabcc58e