"...composable_kernel_rocm.git" did not exist on "b79c7afb737a433829065d5b142bccb4c5c19893"
Commit eff586ac authored by qinletao's avatar qinletao
Browse files

fix ifdef

parent 873d0958
...@@ -50,7 +50,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl ...@@ -50,7 +50,7 @@ using DeviceGemmInstance = ck::tensor_operation::device::DeviceGemmXdl
//##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar| //##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| SrcDstVectorDim| DstScalar|
//##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector| //##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | | PerVector|
//##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | //##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
< F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 2, 2, true, 7, 1>; < F64, F64, F64, F64, Row, Col, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 4, 2, 16, 16, 4, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 1, 1, true, 7, 1>;
// clang-format on // clang-format on
using ReferenceGemmInstance = ck::tensor_operation::host:: using ReferenceGemmInstance = ck::tensor_operation::host::
...@@ -135,8 +135,10 @@ int main(int argc, char* argv[]) ...@@ -135,8 +135,10 @@ int main(int argc, char* argv[])
b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5}); b_k_n.GenerateTensorValue(GeneratorTensor_3<BDataType>{-0.5, 0.5});
break; break;
default: default:
//a_m_k.GenerateTensorValue(GeneratorTensor_2<ADataType>{-5, 5});
b_k_n.GenerateTensorValue(GeneratorTensor_2<BDataType>{-5, 5});
a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1}); a_m_k.GenerateTensorValue(GeneratorTensor_1<ADataType>{1});
b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1}); //b_k_n.GenerateTensorValue(GeneratorTensor_1<BDataType>{1});
} }
DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace()); DeviceMem a_m_k_device_buf(sizeof(ADataType) * a_m_k.mDesc.GetElementSpace());
...@@ -198,13 +200,15 @@ int main(int argc, char* argv[]) ...@@ -198,13 +200,15 @@ int main(int argc, char* argv[])
ref_invoker.Run(ref_argument); ref_invoker.Run(ref_argument);
if(0) #if 1
{ {
LogRangeAsType<double>(std::cout << "a : ", a_m_k.mData, ",") << std::endl; LogRangeAsType<double>(std::cout << "a : ", a_m_k.mData, ",") << std::endl;
LogRangeAsType<double>(std::cout << "b: ", b_k_n.mData, ",") << std::endl; LogRangeAsType<double>(std::cout << "b: ", b_k_n.mData, ",") << std::endl;
LogRangeAsType<double>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") LogRangeAsType<double>(std::cout << "c_device: ", c_m_n_device_result.mData, ",") << std::endl;
LogRangeAsType<double>(std::cout << "c_host: ", c_m_n_host_result.mData, ",")
<< std::endl; << std::endl;
} }
#endif
ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData); ck::utils::check_err(c_m_n_device_result.mData, c_m_n_host_result.mData);
} }
......
...@@ -303,7 +303,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16> ...@@ -303,7 +303,7 @@ struct intrin_mfma_f64_16x16x4f64<16, 16>
template <class FloatC> template <class FloatC>
__device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c) __device__ static void Run(const double& reg_a, const double& reg_b, FloatC& reg_c)
{ {
#ifdef __gxf90a__ #ifdef __gfx90a__
reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64( reg_c.template AsType<double4_t>()(Number<0>{}) = __builtin_amdgcn_mfma_f64_16x16x4f64(
reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0); reg_a, reg_b, reg_c.template AsType<double4_t>()[Number<0>{}], 0, 0, 0);
#else #else
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment