Unverified Commit f92ac4df authored by Bartłomiej Kocot's avatar Bartłomiej Kocot Committed by GitHub
Browse files

Merge branch 'develop' into barkocot/universal-backward-data

parents e4b0b07d 6fb1f4e0
...@@ -59,7 +59,7 @@ struct MultiplyMultiply ...@@ -59,7 +59,7 @@ struct MultiplyMultiply
{ {
const float x0_f = c * d0 * d1; const float x0_f = c * d0 * d1;
e = ck::type_convert<ck::bhalf_t>(x0_f); e = ck::type_convert<ck::half_t>(x0_f);
} }
}; };
...@@ -95,7 +95,7 @@ int main(int argc, char* argv[]) ...@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
ck::index_t K = 4096; ck::index_t K = 4096;
ck::index_t StrideA = K; ck::index_t StrideA = K;
ck::index_t StrideB = N; ck::index_t StrideB = K;
ck::index_t StrideD = 0; ck::index_t StrideD = 0;
ck::index_t StrideE = N; ck::index_t StrideE = N;
...@@ -164,10 +164,10 @@ int main(int argc, char* argv[]) ...@@ -164,10 +164,10 @@ int main(int argc, char* argv[])
{ {
case 0: break; case 0: break;
case 1: case 1:
a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-5, 5}); a0_m_k.GenerateTensorValue(GeneratorTensor_2<A0DataType>{-2, 2});
b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{-5, 5}); b0_k_n.GenerateTensorValue(GeneratorTensor_2<B0DataType>{0, 2});
d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{-5, 5}); d0_m_n.GenerateTensorValue(GeneratorTensor_2<D0DataType>{0, 2});
d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{-5, 5}); d1_m_n.GenerateTensorValue(GeneratorTensor_2<D1DataType>{0, 2});
break; break;
default: default:
a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0}); a0_m_k.GenerateTensorValue(GeneratorTensor_3<A0DataType>{0.0, 1.0});
......
...@@ -83,7 +83,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout, ...@@ -83,7 +83,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
static constexpr index_t NumDTensor = DsDataType::Size(); static constexpr index_t NumDTensor = DsDataType::Size();
// GridwiseGemm // GridwiseGemm
using GridwiseGemm = GridwiseGemm_xdl_cshuffle_v3< using GridwiseGemm = GridwiseGemmMultiD_xdl_cshuffle_v3<
ALayout, ALayout,
BLayout, BLayout,
DsLayout, DsLayout,
......
...@@ -146,7 +146,7 @@ template <typename ALayout, ...@@ -146,7 +146,7 @@ template <typename ALayout,
typename ComputeTypeB = ComputeTypeA, typename ComputeTypeB = ComputeTypeA,
typename LDSTypeA = ADataType, typename LDSTypeA = ADataType,
typename LDSTypeB = BDataType> typename LDSTypeB = BDataType>
struct GridwiseGemm_xdl_cshuffle_v3 struct GridwiseGemmMultiD_xdl_cshuffle_v3
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor( constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc, a_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(Number<MPerBlock / MLdsLayer>{}, make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<AK0Number * MLdsLayer>{})), Number<MPerBlock / MLdsLayer>{}, Number<AK0Number * MLdsLayer>{})),
make_pass_through_transform(AK1Number)), make_pass_through_transform(AK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{})); make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
...@@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple( make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}), make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}), make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_transform( make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})), make_tuple(Number<KThreadReadPerm * M1>{}, Number<kfold * M0 / mpair>{})),
make_pass_through_transform(Number<mpair>{}), make_pass_through_transform(Number<mpair>{}),
make_pass_through_transform(AK1Number)), make_pass_through_transform(AK1Number)),
...@@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc, b_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(Number<NPerBlock / NLdsLayer>{}, make_tuple(make_xor_with_modulo_transform(make_tuple(
Number<BK0Number * NLdsLayer>{})), Number<NPerBlock / NLdsLayer>{}, Number<BK0Number * NLdsLayer>{})),
make_pass_through_transform(BK1Number)), make_pass_through_transform(BK1Number)),
make_tuple(Sequence<1, 0>{}, Sequence<2>{}), make_tuple(Sequence<1, 0>{}, Sequence<2>{}),
make_tuple(Sequence<1, 0>{}, Sequence<2>{})); make_tuple(Sequence<1, 0>{}, Sequence<2>{}));
...@@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3 ...@@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple( make_tuple(
make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}), make_pass_through_transform(Number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(Number<K0PerThreadWrite>{}), make_pass_through_transform(Number<K0PerThreadWrite>{}),
make_xor_transform( make_xor_with_modulo_transform(
make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})), make_tuple(Number<KThreadReadPerm * N1>{}, Number<kfold * N0 / npair>{})),
make_pass_through_transform(Number<npair>{}), make_pass_through_transform(Number<npair>{}),
make_pass_through_transform(BK1Number)), make_pass_through_transform(BK1Number)),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment