"docs/git@developer.sourcefind.cn:change/sglang.git" did not exist on "f8a173bb500afb6f22d2f7045bef354d867b647c"
Commit 4a77f453 authored by ltqin's avatar ltqin
Browse files

add test for fp64

parent 85ef3f28
# device_gemm_instance # device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE set(DEVICE_GEMM_INSTANCE_SOURCE
device_gemm_xdl_f64_f64_f64_mk_kn_mn_instance.cpp;
device_gemm_xdl_f64_f64_f64_mk_nk_mn_instance.cpp;
device_gemm_xdl_f64_f64_f64_km_kn_mn_instance.cpp;
device_gemm_xdl_f64_f64_f64_km_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
......
add_test_executable(test_gemm_fp64 gemm_fp64.cpp)
target_link_libraries(test_gemm_fp64 PRIVATE host_tensor)
target_link_libraries(test_gemm_fp64 PRIVATE device_gemm_instance)
add_test_executable(test_gemm_fp32 gemm_fp32.cpp) add_test_executable(test_gemm_fp32 gemm_fp32.cpp)
target_link_libraries(test_gemm_fp32 PRIVATE host_tensor) target_link_libraries(test_gemm_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance) target_link_libraries(test_gemm_fp32 PRIVATE device_gemm_instance)
......
...@@ -140,12 +140,12 @@ struct TestGemm ...@@ -140,12 +140,12 @@ struct TestGemm
Tensor<CDataType> c_m_n_device_result( Tensor<CDataType> c_m_n_device_result(
f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{})); f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
auto f_generate_tensor_value = [](auto desc, auto type) { auto f_generate_tensor_value = [](auto& desc, auto type) {
using dataType = decltype(type); using dataType = decltype(type);
if(std::is_same<dataType, int8_t>::value) if(std::is_same<dataType, int8_t>::value || std::is_same<dataType, double>::value)
{ {
desc.GenerateTensorValue(GeneratorTensor_2<int8_t>{-5, 5}); desc.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
} }
else else
{ {
...@@ -161,6 +161,7 @@ struct TestGemm ...@@ -161,6 +161,7 @@ struct TestGemm
auto operator()(DeviceGemmPtr_& gemmPtr) auto operator()(DeviceGemmPtr_& gemmPtr)
{ {
std::cout << "data type: " << typeid(ADataType{}).name() << std::endl;
std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
<< ", CLayout = " << CLayout{}.name << std::endl; << ", CLayout = " << CLayout{}.name << std::endl;
std::cout << gemmPtr->GetTypeString() << std::endl; std::cout << gemmPtr->GetTypeString() << std::endl;
...@@ -202,7 +203,12 @@ struct TestGemm ...@@ -202,7 +203,12 @@ struct TestGemm
// Assert // Assert
bool res = false; bool res = false;
if(std::is_same<CDataType, float>::value) if(std::is_same<CDataType, double>::value)
{
res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
}
else if(std::is_same<CDataType, float>::value)
{ {
res = ck::utils::check_err(c_device.mData, c_host.mData); res = ck::utils::check_err(c_device.mData, c_host.mData);
std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl; std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment