Commit 11b83234 authored by Chao Liu's avatar Chao Liu
Browse files

clean up

parent 8e3aef3b
...@@ -284,8 +284,7 @@ struct ThreadwiseTensorSliceTransfer_v3r3 ...@@ -284,8 +284,7 @@ struct ThreadwiseTensorSliceTransfer_v3r3
// TODO make this logic more generic for more sub-dword datatype // TODO make this logic more generic for more sub-dword datatype
if constexpr(SrcVectorDim != DstVectorDim && if constexpr(SrcVectorDim != DstVectorDim &&
is_same<half_t, remove_cvref_t<SrcData>>::value && is_same<half_t, remove_cvref_t<SrcData>>::value &&
(is_same<half_t, remove_cvref_t<DstData>>::value || is_same<half_t, remove_cvref_t<DstData>>::value &&
is_same<bhalf_t, remove_cvref_t<DstData>>::value) &&
SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0) SrcScalarPerVector % 2 == 0 && DstScalarPerVector % 2 == 0)
{ {
// each transpose does // each transpose does
...@@ -344,27 +343,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3 ...@@ -344,27 +343,8 @@ struct ThreadwiseTensorSliceTransfer_v3r3
// do data transpose // do data transpose
// TODO type_convert is not used yet!!!!! // TODO type_convert is not used yet!!!!!
transpose_convert_vectors<SrcData, transpose_vectors<SrcData, DstScalarPerVector, SrcScalarPerVector>{}(
DstData, src_vector_refs, dst_vector_refs);
DstScalarPerVector,
SrcScalarPerVector>{}(src_vector_refs, dst_vector_refs);
});
}
else if constexpr(SrcVectorDim == DstVectorDim && SrcScalarPerVector % 2 == 0 &&
DstScalarPerVector % 2 == 0 &&
is_same<half_t, remove_cvref_t<SrcData>>::value &&
is_same<bhalf_t, remove_cvref_t<DstData>>::value)
{
auto NewSliceLengths = SliceLengths{}.template Modify(
Number<SrcVectorDim>{}, Number<SliceLengths{}[SrcVectorDim] / 2>{});
auto VectorStep = SliceLengths{} / NewSliceLengths;
static_ford<decltype(NewSliceLengths)>{}([&](auto idx) {
// convert from SrcData to DstData here
auto nidx = idx * VectorStep;
auto vhalf =
src_thread_scratch_tuple_[thread_scratch_id].template GetAsType<half2_t>(nidx);
dst_thread_scratch_.template SetAsType<bhalf2_t>(nidx,
type_convert<bhalf2_t>(vhalf));
}); });
} }
else else
......
...@@ -59,4 +59,4 @@ add_subdirectory(batched_gemm_reduce) ...@@ -59,4 +59,4 @@ add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd) add_subdirectory(convnd_fwd)
add_subdirectory(reduce) add_subdirectory(reduce)
# add_subdirectory(conv2d_bwd_weight) add_subdirectory(conv2d_bwd_weight)
...@@ -33,10 +33,10 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo ...@@ -33,10 +33,10 @@ void add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNo
void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
// void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
...@@ -63,8 +63,8 @@ int main() ...@@ -63,8 +63,8 @@ int main()
std::vector<DeviceGemmNoOpPtr> gemmPtrs; std::vector<DeviceGemmNoOpPtr> gemmPtrs;
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
// ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_splitk_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_kn_mn_instances(gemmPtrs);
...@@ -85,8 +85,8 @@ int main() ...@@ -85,8 +85,8 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
// ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_splitk_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_km_nk_mn_instances(gemmPtrs);
...@@ -107,8 +107,8 @@ int main() ...@@ -107,8 +107,8 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
// ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_splitk_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_kn_mn_instances(gemmPtrs);
...@@ -129,8 +129,8 @@ int main() ...@@ -129,8 +129,8 @@ int main()
gemmPtrs.clear(); gemmPtrs.clear();
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
// ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
// add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_splitk_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs); add_device_gemm_xdl_c_shuffle_f16_f16_f16_mk_nk_mn_instances(gemmPtrs);
ck::tensor_operation::device::device_gemm_instance:: ck::tensor_operation::device::device_gemm_instance::
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment