Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
dc0bae32
Commit
dc0bae32
authored
Feb 01, 2023
by
Adam Osewski
Browse files
Merge branch 'develop' into aosewski/wavelet_omniperf
parents
68474822
ba40c2ce
Changes
474
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
504 additions
and
24 deletions
+504
-24
test/gemm_split_k/gemm_split_k.cpp
test/gemm_split_k/gemm_split_k.cpp
+5
-7
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp
+4
-4
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
+1
-1
test/grouped_gemm/grouped_gemm_fp16.cpp
test/grouped_gemm/grouped_gemm_fp16.cpp
+9
-5
test/normalization/test_groupnorm_fp16.cpp
test/normalization/test_groupnorm_fp16.cpp
+1
-1
test/normalization/test_groupnorm_fp32.cpp
test/normalization/test_groupnorm_fp32.cpp
+1
-1
test/normalization/test_layernorm2d_fp16.cpp
test/normalization/test_layernorm2d_fp16.cpp
+1
-1
test/normalization/test_layernorm2d_fp32.cpp
test/normalization/test_layernorm2d_fp32.cpp
+1
-1
test/reduce/reduce_no_index.cpp
test/reduce/reduce_no_index.cpp
+1
-1
test/reduce/reduce_with_index.cpp
test/reduce/reduce_with_index.cpp
+1
-1
test/softmax/test_softmax_util.hpp
test/softmax/test_softmax_util.hpp
+1
-1
test/wmma_op/CMakeLists.txt
test/wmma_op/CMakeLists.txt
+2
-0
test/wmma_op/wmma_op.cpp
test/wmma_op/wmma_op.cpp
+67
-0
test/wmma_op/wmma_op_util.hpp
test/wmma_op/wmma_op_util.hpp
+409
-0
No files found.
test/gemm_split_k/gemm_split_k.cpp
View file @
dc0bae32
...
...
@@ -226,9 +226,8 @@ int main(int argc, char* argv[])
std
::
vector
<
gemmArgs
>
test_cases
;
if
(
argc
==
1
)
{
test_cases
=
{{
GemmMatrixLayout
::
MK_KN_MN
,
3
,
3
,
3
,
3
,
3
,
3
,
1
}};
// JD: Populate with more and meaningful
return
0
;
test_cases
=
{{
GemmMatrixLayout
::
MK_KN_MN
,
1024
,
1024
,
1024
,
1024
,
1024
,
1024
,
2
},
{
GemmMatrixLayout
::
MK_KN_MN
,
1024
,
1024
,
1024
,
1024
,
1024
,
1024
,
8
}};
}
else
if
(
argc
==
9
)
{
...
...
@@ -253,11 +252,10 @@ int main(int argc, char* argv[])
printf
(
"arg2 to 7: M, N, K, StrideA, StrideB, StrideC KBatch
\n
"
);
return
-
1
;
}
bool
error
=
false
;
for
(
const
auto
&
kinder
:
test_cases
)
{
const
auto
res
=
test_gemm
(
kinder
);
if
(
!
res
)
return
-
1
;
error
|=
test_gemm
(
kinder
);
}
return
0
;
return
error
?
1
:
0
;
}
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp
View file @
dc0bae32
...
...
@@ -9,7 +9,7 @@
#include <gtest/gtest.h>
#include "profiler/
include/
profile_grouped_conv_bwd_weight_impl.hpp"
#include "profiler/profile_grouped_conv_bwd_weight_impl.hpp"
template
<
typename
Tuple
>
class
TestGroupedConvndBwdWeight
:
public
::
testing
::
Test
...
...
@@ -61,7 +61,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test1D)
{
this
->
conv_params
.
clear
();
this
->
conv_params
.
push_back
({
1
,
4
,
128
,
128
,
256
,
{
1
},
{
14
},
{
2
},
{
1
},
{
0
},
{
0
}});
this
->
conv_params
.
push_back
({
1
,
4
,
128
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
4
,
64
,
128
,
256
,
{
3
},
{
28
},
{
1
},
{
1
},
{
1
},
{
1
}});
this
->
conv_params
.
push_back
({
1
,
4
,
128
,
128
,
256
,
{
1
},
{
3
},
{
1
},
{
1
},
{
0
},
{
0
}});
this
->
template
Run
<
1
>();
}
...
...
@@ -72,7 +72,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D)
this
->
conv_params
.
push_back
(
{
2
,
4
,
128
,
128
,
256
,
{
1
,
1
},
{
7
,
7
},
{
2
,
2
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
conv_params
.
push_back
(
{
2
,
4
,
32
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
{
2
,
4
,
8
,
128
,
256
,
{
3
,
3
},
{
14
,
14
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
},
{
1
,
1
}});
this
->
conv_params
.
push_back
(
{
2
,
4
,
128
,
128
,
256
,
{
1
,
1
},
{
3
,
3
},
{
1
,
1
},
{
1
,
1
},
{
0
,
0
},
{
0
,
0
}});
this
->
template
Run
<
2
>();
...
...
@@ -84,7 +84,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D)
this
->
conv_params
.
push_back
(
{
3
,
4
,
128
,
128
,
256
,
{
1
,
1
,
1
},
{
7
,
7
,
7
},
{
2
,
2
,
2
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
conv_params
.
push_back
(
{
3
,
4
,
32
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
{
3
,
4
,
8
,
128
,
256
,
{
3
,
3
,
3
},
{
14
,
14
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
1
,
1
,
1
}});
this
->
conv_params
.
push_back
(
{
3
,
4
,
128
,
128
,
256
,
{
1
,
1
,
1
},
{
3
,
3
,
3
},
{
1
,
1
,
1
},
{
1
,
1
,
1
},
{
0
,
0
,
0
},
{
0
,
0
,
0
}});
this
->
template
Run
<
3
>();
...
...
test/grouped_convnd_fwd/grouped_convnd_fwd.cpp
View file @
dc0bae32
...
...
@@ -7,7 +7,7 @@
#include <vector>
#include <gtest/gtest.h>
#include "profiler/
include/
profile_grouped_conv_fwd_impl.hpp"
#include "profiler/profile_grouped_conv_fwd_impl.hpp"
class
TestGroupedConvNdFwd
:
public
::
testing
::
Test
{
...
...
test/grouped_gemm/grouped_gemm_fp16.cpp
View file @
dc0bae32
...
...
@@ -2,8 +2,9 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <iostream>
#include <random>
#include "profiler/
include/
profile_grouped_gemm_impl.hpp"
#include "profiler/profile_grouped_gemm_impl.hpp"
namespace
{
...
...
@@ -18,7 +19,10 @@ using Col = ck::tensor_layout::gemm::ColumnMajor;
template
<
typename
ALayout
,
typename
BLayout
,
typename
CLayout
>
bool
TestGroupedGemm
()
{
int
group_count
=
rand
()
%
10
+
1
;
std
::
mt19937
gen
(
19391
);
std
::
uniform_int_distribution
<>
distrib
(
1
,
10
);
int
group_count
=
distrib
(
gen
);
// GEMM shape
std
::
vector
<
ck
::
tensor_operation
::
device
::
GemmDesc
>
gemm_descs
;
...
...
@@ -29,9 +33,9 @@ bool TestGroupedGemm()
for
(
int
i
=
0
;
i
<
group_count
;
i
++
)
{
Ms
.
push_back
(
256
+
256
*
(
rand
()
%
10
));
Ns
.
push_back
(
256
+
256
*
(
rand
()
%
10
));
Ks
.
push_back
(
128
+
128
*
(
rand
()
%
10
));
Ms
.
push_back
(
256
+
256
*
distrib
(
gen
));
Ns
.
push_back
(
256
+
256
*
distrib
(
gen
));
Ks
.
push_back
(
128
+
128
*
distrib
(
gen
));
StrideAs
.
push_back
(
std
::
is_same
<
Row
,
ALayout
>::
value
?
Ks
[
i
]
:
Ms
[
i
]);
StrideBs
.
push_back
(
std
::
is_same
<
Row
,
BLayout
>::
value
?
Ns
[
i
]
:
Ks
[
i
]);
...
...
test/normalization/test_groupnorm_fp16.cpp
View file @
dc0bae32
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/
include/
profile_groupnorm_impl.hpp"
#include "profiler/profile_groupnorm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
...
...
test/normalization/test_groupnorm_fp32.cpp
View file @
dc0bae32
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/
include/
profile_groupnorm_impl.hpp"
#include "profiler/profile_groupnorm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
...
...
test/normalization/test_layernorm2d_fp16.cpp
View file @
dc0bae32
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/
include/
profile_layernorm_impl.hpp"
#include "profiler/profile_layernorm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
...
...
test/normalization/test_layernorm2d_fp32.cpp
View file @
dc0bae32
...
...
@@ -2,7 +2,7 @@
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "profiler/
include/
profile_layernorm_impl.hpp"
#include "profiler/profile_layernorm_impl.hpp"
using
F16
=
ck
::
half_t
;
using
F32
=
float
;
...
...
test/reduce/reduce_no_index.cpp
View file @
dc0bae32
...
...
@@ -4,7 +4,7 @@
#include <getopt.h>
#include "ck/library/utility/host_common_util.hpp"
#include "profiler/
include/
profile_reduce_impl.hpp"
#include "profiler/profile_reduce_impl.hpp"
using
namespace
ck
;
...
...
test/reduce/reduce_with_index.cpp
View file @
dc0bae32
...
...
@@ -4,7 +4,7 @@
#include <getopt.h>
#include "ck/library/utility/host_common_util.hpp"
#include "profiler/
include/
profile_reduce_impl.hpp"
#include "profiler/profile_reduce_impl.hpp"
using
namespace
ck
;
...
...
test/softmax/test_softmax_util.hpp
View file @
dc0bae32
...
...
@@ -13,7 +13,7 @@
#include "ck/tensor_operation/gpu/device/impl/device_softmax_impl.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "include/ck/utility/data_type.hpp"
#include "profiler/
include/
profile_softmax_impl.hpp"
#include "profiler/profile_softmax_impl.hpp"
namespace
ck
{
...
...
test/wmma_op/CMakeLists.txt
0 → 100644
View file @
dc0bae32
add_test_executable
(
test_wmma_op wmma_op.cpp
)
target_link_libraries
(
test_wmma_op PRIVATE utility
)
test/wmma_op/wmma_op.cpp
0 → 100644
View file @
dc0bae32
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"
#include "test/wmma_op/wmma_op_util.hpp"
template
<
typename
SrcType
,
typename
DstType
,
typename
GPUAccType
,
typename
CPUAccType
,
ck
::
index_t
AccNum
>
bool
run_test
()
{
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
PassThrough
=
ck
::
tensor_operation
::
element_wise
::
PassThrough
;
bool
pass
=
true
;
const
auto
matmul_default
=
ck
::
wmma_op_util
::
matmul
<
SrcType
,
DstType
,
GPUAccType
,
AccNum
>
;
const
auto
matmul_swizzle_a
=
ck
::
wmma_op_util
::
matmul_swizzle_a
<
SrcType
,
DstType
,
GPUAccType
,
AccNum
>
;
const
auto
wmma_kernel_container
=
std
::
make_tuple
(
matmul_default
,
matmul_swizzle_a
);
ck
::
static_for
<
0
,
2
,
1
>
{}([
&
](
auto
i
)
{
pass
&=
ck
::
wmma_op_util
::
TestWmma
<
decltype
(
std
::
get
<
ck
::
Number
<
i
>
{}
>
(
wmma_kernel_container
)),
SrcType
,
SrcType
,
DstType
,
GPUAccType
,
CPUAccType
,
decltype
(
Row
{}),
decltype
(
Col
{}),
decltype
(
Row
{}),
PassThrough
,
PassThrough
,
PassThrough
,
AccNum
>
{}(
std
::
get
<
ck
::
Number
<
i
>
{}
>
(
wmma_kernel_container
));
});
return
pass
?
1
:
0
;
}
int
main
(
int
,
char
*
[])
{
bool
pass
=
true
;
// clang-format off
// |SrcType |DstType |GPUAccType |CPUAccType |AccNum
pass
&=
run_test
<
ck
::
half_t
,
ck
::
half_t
,
float
,
float
,
8
>
();
pass
&=
run_test
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
float
,
8
>
();
pass
&=
run_test
<
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
ck
::
half_t
,
16
>
();
pass
&=
run_test
<
ck
::
bhalf_t
,
ck
::
bhalf_t
,
ck
::
bhalf_t
,
float
,
16
>
();
pass
&=
run_test
<
int8_t
,
int8_t
,
int32_t
,
int32_t
,
8
>
();
// clang-format on
std
::
cout
<<
"TestGemm ..... "
<<
(
pass
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
pass
?
0
:
1
;
}
test/wmma_op/wmma_op_util.hpp
0 → 100644
View file @
dc0bae32
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"
namespace
ck
{
namespace
wmma_op_util
{
template
<
typename
src_vec
,
typename
acc_vec
>
__device__
void
builtin_wmma_naive_selector
(
const
src_vec
&
,
const
src_vec
&
,
acc_vec
&
)
{
}
template
<
>
__device__
void
builtin_wmma_naive_selector
<
half16_t
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
float
,
1
,
8
,
true
>>
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
float
,
1
,
8
,
true
>&
reg_c
)
{
intrin_wmma_f32_16x16x16_f16_w32
<
16
,
16
>::
Run
(
reg_a
,
reg_b
,
reg_c
.
GetVectorTypeReference
(
Number
<
0
>
{}));
}
template
<
>
__device__
void
builtin_wmma_naive_selector
<
bhalf16_t
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
float
,
1
,
8
,
true
>>
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
float
,
1
,
8
,
true
>&
reg_c
)
{
intrin_wmma_f32_16x16x16_bf16_w32
<
16
,
16
>::
Run
(
reg_a
,
reg_b
,
reg_c
.
GetVectorTypeReference
(
Number
<
0
>
{}));
}
template
<
>
__device__
void
builtin_wmma_naive_selector
<
half16_t
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
half_t
,
1
,
16
,
true
>>
(
const
half16_t
&
reg_a
,
const
half16_t
&
reg_b
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
half_t
,
1
,
16
,
true
>&
reg_c
)
{
intrin_wmma_f16_16x16x16_f16_w32
<
16
,
16
,
0
>::
Run
(
reg_a
,
reg_b
,
reg_c
.
GetVectorTypeReference
(
Number
<
0
>
{}));
}
template
<
>
__device__
void
builtin_wmma_naive_selector
<
bhalf16_t
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
bhalf_t
,
1
,
16
,
true
>>
(
const
bhalf16_t
&
reg_a
,
const
bhalf16_t
&
reg_b
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
bhalf_t
,
1
,
16
,
true
>&
reg_c
)
{
intrin_wmma_bf16_16x16x16_bf16_w32
<
16
,
16
,
0
>::
Run
(
reg_a
,
reg_b
,
reg_c
.
GetVectorTypeReference
(
Number
<
0
>
{}));
}
template
<
>
__device__
void
builtin_wmma_naive_selector
<
int8x16_t
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
1
,
8
,
true
>>
(
const
int8x16_t
&
reg_a
,
const
int8x16_t
&
reg_b
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
1
,
8
,
true
>&
reg_c
)
{
intrin_wmma_i32_16x16x16_iu8_w32
<
16
,
16
,
true
,
true
,
false
>::
Run
(
reg_a
,
reg_b
,
reg_c
.
GetVectorTypeReference
(
Number
<
0
>
{}));
}
#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template
<
>
__device__
void
builtin_wmma_naive_selector
<
int4x16_t
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
1
,
8
,
true
>>
(
const
int4x16_t
&
reg_a
,
const
int4x16_t
&
reg_b
,
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
int32_t
,
1
,
8
,
true
>&
reg_c
)
{
intrin_wmma_i32_16x16x16_iu4_w32
<
16
,
16
,
true
,
true
,
false
>::
Run
(
reg_a
,
reg_b
,
reg_c
.
GetVectorTypeReference
(
Number
<
0
>
{}));
}
#endif
template
<
typename
src_t
,
typename
dst_t
,
typename
acc_t
,
index_t
acc_num
>
__global__
void
matmul
(
const
src_t
*
a
,
const
src_t
*
b
,
dst_t
*
c
)
{
__shared__
src_t
p_shared
[
16
*
16
*
2
];
const
int
lIdx
=
threadIdx
.
x
;
// a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and
// b a_frag will store one column of the 16x16 matrix tile b_frag will store one row of the
// 16x16 matrix tile
using
src_vec
=
typename
vector_type
<
src_t
,
16
>::
type
;
src_vec
a_frag
=
{};
src_vec
b_frag
=
{};
src_vec
a_temp
=
{};
src_vec
b_temp
=
{};
// initialize c fragment to 0
using
acc_vec
=
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
acc_t
,
1
,
acc_num
,
true
>
;
acc_vec
c_thread_buf_
;
// lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
// see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
// TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
const
int
lane
=
lIdx
%
16
;
const
int
lane_lo
=
lIdx
/
2
;
const
int
lane_hi
=
lIdx
%
2
;
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
{
a_temp
[
ele
]
=
a
[
8
*
lane_hi
+
16
*
lane_lo
+
ele
];
}
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
{
b_temp
[
ele
]
=
b
[
8
*
lane_hi
+
16
*
lane_lo
+
ele
];
}
__syncthreads
();
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
{
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
]
=
a_temp
[
ele
];
}
for
(
int
ele
=
0
;
ele
<
8
;
++
ele
)
{
p_shared
[
8
*
16
*
lane_hi
+
8
*
lane_lo
+
ele
+
16
*
16
]
=
b_temp
[
ele
];
}
asm
volatile
(
"\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
b_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
+
16
*
16
];
}
// follow origin design
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
a_frag
[
ele
]
=
p_shared
[(
ele
/
8
)
*
16
*
8
+
8
*
lane
+
ele
%
8
];
}
asm
volatile
(
"\
s_waitcnt lgkmcnt(0)
\n
\
s_barrier \
"
::
);
// sync threads, similar to mma_sync
// __syncthreads();
builtin_wmma_naive_selector
<
src_vec
,
acc_vec
>
(
a_frag
,
b_frag
,
c_thread_buf_
);
// since only fp16_fp32 asm wmma implemented for experiment purpose, restrict test case to fp16
// when enable this ck::amd_assembly_wmma_f32_16x16x16_f16_w32(a_frag, b_frag,
// c_thread_buf_.GetVectorTypeReference(Number<0>{}).template AsType<float8_t>()(Number<0>{}));
__syncthreads
();
// wait for results, similar to mma_sync
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
ele
)
{
const
int
r
=
ele
*
2
+
(
lIdx
/
16
);
// store results from unpacked c_thread_buf_ output
c
[
16
*
r
+
lane
]
=
ck
::
type_convert
<
dst_t
>
(
c_thread_buf_
[
Number
<
ele
*
acc_num
/
8
>
{}]);
});
}
template
<
typename
src_t
,
typename
dst_t
,
typename
acc_t
,
index_t
acc_num
>
__global__
void
matmul_swizzle_a
(
const
src_t
*
a
,
const
src_t
*
b
,
dst_t
*
c
)
{
const
int
lIdx
=
threadIdx
.
x
;
using
src_vec
=
typename
vector_type
<
src_t
,
16
>::
type
;
src_vec
a_frag
=
{};
src_vec
b_frag
=
{};
using
acc_vec
=
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
acc_t
,
1
,
acc_num
,
true
>
;
acc_vec
c_thread_buf_
;
const
int
lane
=
lIdx
%
16
;
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
b_frag
[
ele
]
=
b
[
16
*
lane
+
ele
];
}
const
int
offset_m
=
(((
lane
&
1
)
<<
3
)
|
(
lane
>>
1
));
for
(
int
ele
=
0
;
ele
<
16
;
++
ele
)
{
a_frag
[
ele
]
=
a
[
16
*
offset_m
+
ele
];
}
__syncthreads
();
builtin_wmma_naive_selector
<
src_vec
,
acc_vec
>
(
a_frag
,
b_frag
,
c_thread_buf_
);
__syncthreads
();
static_for
<
0
,
8
,
1
>
{}([
&
](
auto
ele
)
{
const
int
blk
=
lIdx
/
16
;
const
int
r
=
ele
;
c
[
16
*
8
*
blk
+
16
*
r
+
lane
]
=
ck
::
type_convert
<
dst_t
>
(
c_thread_buf_
[
Number
<
ele
*
acc_num
/
8
>
{}]);
});
}
struct
GemmParams
{
GemmParams
()
:
M
(
16
),
N
(
16
),
K
(
16
),
StrideA
(
16
),
StrideB
(
16
),
StrideC
(
16
),
alpha
(
1
),
beta
(
0
)
{}
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
StrideA
;
ck
::
index_t
StrideB
;
ck
::
index_t
StrideC
;
float
alpha
;
float
beta
;
};
template
<
typename
GemmInstance
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
RunHostGEMM
(
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
BDataType
>&
B
,
Tensor
<
CDataType
>&
C
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
auto
ref_gemm
=
GemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
A
,
B
,
C
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
}
template
<
typename
KernelType
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
>
bool
RunDeviceGEMM
(
KernelType
kernel
,
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
BDataType
>&
B
,
Tensor
<
CDataType
>&
C
)
{
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
b_n_k_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpaceSize
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpaceSize
());
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
b_n_k_device_buf
.
ToDevice
(
B
.
mData
.
data
());
kernel
<<<
1
,
32
>>>
(
static_cast
<
const
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
const
BDataType
*>
(
b_n_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()));
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
return
true
;
}
template
<
typename
DeviceWmma
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
GPUAccDataType
,
typename
CPUAccDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
index_t
CAccNum
>
struct
TestWmma
{
auto
PrepareGemmTensor
(
const
ck
::
wmma_op_util
::
GemmParams
&
params
)
{
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_n_k
(
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
auto
f_generate_tensor_value
=
[](
auto
&
tensor
,
auto
type
)
{
using
dataType
=
decltype
(
type
);
tensor
.
GenerateTensorValue
(
GeneratorTensor_2
<
dataType
>
{
-
5
,
5
});
};
f_generate_tensor_value
(
a_m_k
,
ADataType
{});
f_generate_tensor_value
(
b_n_k
,
BDataType
{});
return
std
::
make_tuple
(
a_m_k
,
b_n_k
,
c_m_n_host_result
,
c_m_n_device_result
);
}
auto
operator
()(
const
DeviceWmma
&
wmma_kernel
)
{
std
::
cout
<<
"ALayout = "
<<
ALayout
{}.
name
<<
", BLayout = "
<<
BLayout
{}.
name
<<
", CLayout = "
<<
CLayout
{}.
name
<<
std
::
endl
;
// Arrange
ck
::
wmma_op_util
::
GemmParams
params
;
params
.
M
=
16
;
params
.
N
=
16
;
params
.
K
=
16
;
params
.
StrideA
=
16
;
params
.
StrideB
=
16
;
params
.
StrideC
=
16
;
auto
host_tensors
=
PrepareGemmTensor
(
params
);
const
Tensor
<
ADataType
>&
a
=
std
::
get
<
0
>
(
host_tensors
);
const
Tensor
<
BDataType
>&
b
=
std
::
get
<
1
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_host
=
std
::
get
<
2
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device
=
std
::
get
<
3
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
c_element_op
=
CElementwiseOperation
{};
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
CPUAccDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
ck
::
wmma_op_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
// Act
bool
is_supported
=
ck
::
wmma_op_util
::
RunDeviceGEMM
(
wmma_kernel
,
a
,
b
,
c_device
);
if
(
is_supported
)
{
// Assert
bool
res
=
false
;
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
bhalf_t
>::
value
)
{
// 0.5 Pixel Error Tolerance is introduced by Accumulator difference.
// BF16 WMMA Accumulator is in BF16 Type while On Host-side Accumulator is Float.
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
,
"Error: Incorrect results!"
,
0
,
1.0
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
double
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
{
std
::
cout
<<
"UNSUPPORTED CDataType"
<<
std
::
endl
;
}
return
res
;
}
else
{
return
true
;
}
}
};
}
// namespace wmma_op_util
}
// namespace ck
Prev
1
…
20
21
22
23
24
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment