Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
f92ac4df
Unverified
Commit
f92ac4df
authored
Jun 01, 2024
by
Bartłomiej Kocot
Committed by
GitHub
Jun 01, 2024
Browse files
Merge branch 'develop' into barkocot/universal-backward-data
parents
e4b0b07d
6fb1f4e0
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
14 additions
and
14 deletions
+14
-14
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp
...emm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp
+6
-6
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
...pu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
+1
-1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
...ration/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
+7
-7
No files found.
example/65_gemm_multiply_multiply/gemm_multiply_multiply_xdl_fp16.cpp
View file @
f92ac4df
...
...
@@ -59,7 +59,7 @@ struct MultiplyMultiply
{
const
float
x0_f
=
c
*
d0
*
d1
;
e
=
ck
::
type_convert
<
ck
::
b
half_t
>
(
x0_f
);
e
=
ck
::
type_convert
<
ck
::
half_t
>
(
x0_f
);
}
};
...
...
@@ -95,7 +95,7 @@ int main(int argc, char* argv[])
ck
::
index_t
K
=
4096
;
ck
::
index_t
StrideA
=
K
;
ck
::
index_t
StrideB
=
N
;
ck
::
index_t
StrideB
=
K
;
ck
::
index_t
StrideD
=
0
;
ck
::
index_t
StrideE
=
N
;
...
...
@@ -164,10 +164,10 @@ int main(int argc, char* argv[])
{
case
0
:
break
;
case
1
:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
5
,
5
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
-
5
,
5
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
-
5
,
5
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
-
5
,
5
});
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
A0DataType
>
{
-
2
,
2
});
b0_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
B0DataType
>
{
0
,
2
});
d0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D0DataType
>
{
0
,
2
});
d1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
D1DataType
>
{
0
,
2
});
break
;
default:
a0_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
A0DataType
>
{
0.0
,
1.0
});
...
...
include/ck/tensor_operation/gpu/device/impl/device_gemm_multiple_d_xdl_cshuffle_v3.hpp
View file @
f92ac4df
...
...
@@ -83,7 +83,7 @@ struct DeviceGemmMultiD_Xdl_CShuffle_V3 : public DeviceGemmMultipleD<ALayout,
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
// GridwiseGemm
using
GridwiseGemm
=
GridwiseGemm_xdl_cshuffle_v3
<
using
GridwiseGemm
=
GridwiseGemm
MultiD
_xdl_cshuffle_v3
<
ALayout
,
BLayout
,
DsLayout
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3_multi_d.hpp
View file @
f92ac4df
...
...
@@ -146,7 +146,7 @@ template <typename ALayout,
typename
ComputeTypeB
=
ComputeTypeA
,
typename
LDSTypeA
=
ADataType
,
typename
LDSTypeB
=
BDataType
>
struct
GridwiseGemm_xdl_cshuffle_v3
struct
GridwiseGemm
MultiD
_xdl_cshuffle_v3
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
...
@@ -690,8 +690,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr
auto
a_lds_block_desc_permuted
=
transform_tensor_descriptor
(
a_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
Number
<
MPerBlock
/
MLdsLayer
>
{},
Number
<
AK0Number
*
MLdsLayer
>
{})),
make_tuple
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
MPerBlock
/
MLdsLayer
>
{},
Number
<
AK0Number
*
MLdsLayer
>
{})),
make_pass_through_transform
(
AK1Number
)),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
...
...
@@ -756,7 +756,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
KThreadReadPerm
*
M1
>
{},
Number
<
kfold
*
M0
/
mpair
>
{})),
make_pass_through_transform
(
Number
<
mpair
>
{}),
make_pass_through_transform
(
AK1Number
)),
...
...
@@ -827,8 +827,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_xor_transform
(
make_tuple
(
Number
<
NPerBlock
/
NLdsLayer
>
{},
Number
<
BK0Number
*
NLdsLayer
>
{})),
make_tuple
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
NPerBlock
/
NLdsLayer
>
{},
Number
<
BK0Number
*
NLdsLayer
>
{})),
make_pass_through_transform
(
BK1Number
)),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
1
,
0
>
{},
Sequence
<
2
>
{}));
...
...
@@ -890,7 +890,7 @@ struct GridwiseGemm_xdl_cshuffle_v3
make_tuple
(
make_pass_through_transform
(
Number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
Number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_xor_
with_modulo_
transform
(
make_tuple
(
Number
<
KThreadReadPerm
*
N1
>
{},
Number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
Number
<
npair
>
{}),
make_pass_through_transform
(
BK1Number
)),
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment