gaoqiong / composable_kernel / Commits

Commit e6715976, authored Dec 15, 2022 by letaoqin

    Merge branch 'develop' into dl_conv_multiple_d

Parents: ca313a29, 10c72ace

Changes: 45 files in total; this page shows 5 changed files, with 444 additions and 3 deletions.
test/CMakeLists.txt                                           +3   -0
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp  +3   -3
test/wmma_op/CMakeLists.txt                                   +2   -0
test/wmma_op/wmma_op.cpp                                      +67  -0
test/wmma_op/wmma_op_util.hpp                                 +369 -0
test/CMakeLists.txt

@@ -55,3 +55,6 @@ add_subdirectory(normalization)
 add_subdirectory(data_type)
 add_subdirectory(elementwise_normalization)
 add_subdirectory(batchnorm)
+if(GPU_TARGETS MATCHES "gfx1100")
+    add_subdirectory(wmma_op)
+endif()
test/grouped_convnd_bwd_weight/grouped_convnd_bwd_weight.cpp

@@ -61,7 +61,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test1D)
 {
     this->conv_params.clear();
     this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
-    this->conv_params.push_back({1, 4, 128, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
+    this->conv_params.push_back({1, 4, 64, 128, 256, {3}, {28}, {1}, {1}, {1}, {1}});
     this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {3}, {1}, {1}, {0}, {0}});
     this->template Run<1>();
 }
@@ -72,7 +72,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test2D)
     this->conv_params.push_back(
         {2, 4, 128, 128, 256, {1, 1}, {7, 7}, {2, 2}, {1, 1}, {0, 0}, {0, 0}});
     this->conv_params.push_back(
-        {2, 4, 32, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+        {2, 4, 8, 128, 256, {3, 3}, {14, 14}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
     this->conv_params.push_back(
         {2, 4, 128, 128, 256, {1, 1}, {3, 3}, {1, 1}, {1, 1}, {0, 0}, {0, 0}});
     this->template Run<2>();
@@ -84,7 +84,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight, Test3D)
     this->conv_params.push_back(
         {3, 4, 128, 128, 256, {1, 1, 1}, {7, 7, 7}, {2, 2, 2}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
     this->conv_params.push_back(
-        {3, 4, 32, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+        {3, 4, 8, 128, 256, {3, 3, 3}, {14, 14, 3}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
     this->conv_params.push_back(
         {3, 4, 128, 128, 256, {1, 1, 1}, {3, 3, 3}, {1, 1, 1}, {1, 1, 1}, {0, 0, 0}, {0, 0, 0}});
     this->template Run<3>();
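[Annotation, not part of the commit] The brace-initializers passed to conv_params.push_back above pack the whole problem description into one list. Judging from the test values, the field order appears to be: spatial dimensionality, group count, batch size, output channels, input channels, then per-dimension filter lengths, input lengths, strides, dilations, left pads, and right pads; treat that ordering as an assumption to verify against the fixture's parameter type. A minimal sketch under that assumption:

// Annotation sketch (hypothetical struct and field names; the real type lives
// in the test fixture / ck utility headers): decoding one Test1D entry.
#include <cassert>
#include <vector>

struct ConvParamSketch
{
    int ndim;                            // number of spatial dimensions
    int G, N, K, C;                      // groups, batch, out channels, in channels
    std::vector<int> filter, input;      // per-dimension lengths
    std::vector<int> strides, dilations; // per-dimension conv strides/dilations
    std::vector<int> left_pads, right_pads;
};

int main()
{
    // Mirrors this->conv_params.push_back({1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}});
    ConvParamSketch p{1, 4, 128, 128, 256, {1}, {14}, {2}, {1}, {0}, {0}};
    assert(p.ndim == 1 && p.G == 4 && p.strides[0] == 2); // 1D, 4 groups, stride 2, unpadded
    return 0;
}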
test/wmma_op/CMakeLists.txt (new file, mode 100644)

add_test_executable(test_wmma_op wmma_op.cpp)
target_link_libraries(test_wmma_op PRIVATE utility)
test/wmma_op/wmma_op.cpp (new file, mode 100644)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#include <algorithm>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <tuple>
#include <vector>

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp"

#include "test/wmma_op/wmma_op_util.hpp"

template <typename SrcType,
          typename DstType,
          typename GPUAccType,
          typename CPUAccType,
          ck::index_t AccNum>
bool run_test()
{
    using Row         = ck::tensor_layout::gemm::RowMajor;
    using Col         = ck::tensor_layout::gemm::ColumnMajor;
    using PassThrough = ck::tensor_operation::element_wise::PassThrough;

    bool pass = true;

    const auto matmul_default = ck::wmma_op_util::matmul<SrcType, DstType, GPUAccType, AccNum>;
    const auto matmul_swizzle_a =
        ck::wmma_op_util::matmul_swizzle_a<SrcType, DstType, GPUAccType, AccNum>;

    const auto wmma_kernel_container = std::make_tuple(matmul_default, matmul_swizzle_a);

    ck::static_for<0, 2, 1>{}([&](auto i) {
        pass &= ck::wmma_op_util::TestWmma<
            decltype(std::get<ck::Number<i>{}>(wmma_kernel_container)),
            SrcType,
            SrcType,
            DstType,
            GPUAccType,
            CPUAccType,
            decltype(Row{}),
            decltype(Col{}),
            decltype(Row{}),
            PassThrough,
            PassThrough,
            PassThrough,
            AccNum>{}(std::get<ck::Number<i>{}>(wmma_kernel_container));
    });

    return pass ? 1 : 0;
}

int main(int, char*[])
{
    bool pass = true;
    // clang-format off
    //               |SrcType      |DstType      |GPUAccType  |CPUAccType  |AccNum
    pass &= run_test<ck::half_t,    ck::half_t,   float,       float,       8>();
    pass &= run_test<ck::bhalf_t,   ck::bhalf_t,  float,       float,       8>();
    pass &= run_test<ck::half_t,    ck::half_t,   ck::half_t,  ck::half_t,  16>();
    pass &= run_test<ck::bhalf_t,   ck::bhalf_t,  ck::bhalf_t, float,       16>();
    pass &= run_test<int8_t,        int8_t,       int32_t,     int32_t,     8>();
    // clang-format on

    std::cout << "TestGemm ..... " << (pass ? "SUCCESS" : "FAILURE") << std::endl;
    return pass ? 0 : 1;
}
test/wmma_op/wmma_op_util.hpp (new file, mode 100644)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck/ck.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/library/utility/check_err.hpp"
#include "ck/library/utility/device_memory.hpp"
#include "ck/library/utility/host_tensor.hpp"
#include "ck/library/utility/host_tensor_generator.hpp"
#include "ck/library/reference_tensor_operation/cpu/reference_gemm.hpp"
#include "ck/utility/amd_wmma.hpp"

namespace ck {
namespace wmma_op_util {

template <typename src_vec, typename acc_vec>
__device__ void builtin_wmma_naive_selector(const src_vec&, const src_vec&, acc_vec&)
{
}

template <>
__device__ void
builtin_wmma_naive_selector<half16_t,
                            StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>>(
    const half16_t& reg_a,
    const half16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>& reg_c)
{
    intrin_wmma_f32_16x16x16_f16_w32<16, 16>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void
builtin_wmma_naive_selector<bhalf16_t,
                            StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>>(
    const bhalf16_t& reg_a,
    const bhalf16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, float, 1, 8, true>& reg_c)
{
    intrin_wmma_f32_16x16x16_bf16_w32<16, 16>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void
builtin_wmma_naive_selector<half16_t,
                            StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true>>(
    const half16_t& reg_a,
    const half16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, half_t, 1, 16, true>& reg_c)
{
    intrin_wmma_f16_16x16x16_f16_w32<16, 16, 0>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void builtin_wmma_naive_selector<
    bhalf16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, bhalf_t, 1, 16, true>>(
    const bhalf16_t& reg_a,
    const bhalf16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, bhalf_t, 1, 16, true>& reg_c)
{
    intrin_wmma_bf16_16x16x16_bf16_w32<16, 16, 0>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

template <>
__device__ void builtin_wmma_naive_selector<
    int8x16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>>(
    const int8x16_t& reg_a,
    const int8x16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>& reg_c)
{
    intrin_wmma_i32_16x16x16_iu8_w32<16, 16, true, true, false>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}

#ifdef CK_EXPERIMENTAL_BIT_INT_EXTENSION_INT4
template <>
__device__ void builtin_wmma_naive_selector<
    int4x16_t,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>>(
    const int4x16_t& reg_a,
    const int4x16_t& reg_b,
    StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, int32_t, 1, 8, true>& reg_c)
{
    intrin_wmma_i32_16x16x16_iu4_w32<16, 16, true, true, false>::Run(
        reg_a, reg_b, reg_c.GetVectorTypeReference(Number<0>{}));
}
#endif

template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul(const src_t* a, const src_t* b, dst_t* c)
{
    const int lIdx = threadIdx.x;

    // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a
    // and b; a_frag will store one column of the 16x16 matrix tile, b_frag will store one row of
    // the 16x16 matrix tile
    using src_vec  = typename vector_type<src_t, 16>::type;
    src_vec a_frag = {};
    src_vec b_frag = {};

    // initialize c fragment to 0
    using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
    acc_vec c_thread_buf_;

    // lane is (0-31) mod 16 instead of 0-31 due to matrix replication in gfx11
    // see https://atlvsp3.amd.com/sp3_gfx11_5_instructions.pdf page 482
    // TODO: remove this dependency in gfx12 https://ontrack-internal.amd.com/browse/DEGFXSP3-101
    const int lane = lIdx % 16;

    for(int ele = 0; ele < 16; ++ele)
    {
        b_frag[ele] = b[16 * lane + ele];
    }

    // follow origin design
    for(int ele = 0; ele < 16; ++ele)
    {
        a_frag[ele] = a[16 * lane + ele];
    }

    // sync threads, similar to mma_sync
    __syncthreads();

    builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);

    __syncthreads();

    // wait for results, similar to mma_sync
    static_for<0, 8, 1>{}([&](auto ele) {
        const int r = ele * 2 + (lIdx / 16);
        // store results from unpacked c_thread_buf_ output
        c[16 * r + lane] = ck::type_convert<dst_t>(c_thread_buf_[Number<ele * acc_num / 8>{}]);
    });
}
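[Annotation, not part of the commit] In the store loop above, each lane writes 8 of the 256 tile elements: r = ele * 2 + (lIdx / 16) interleaves the two 16-lane halves of the wave across the 16 output rows, and the accumulator slot index ele * acc_num / 8 visits every slot when acc_num is 8 but only the even slots when acc_num is 16, which is consistent with the 16-bit-accumulator intrinsics leaving results in one half of each VGPR pair. A host-side sketch of that index arithmetic:

// Annotation sketch: the row/slot mapping used by matmul's store loop,
// replayed on the host for one wave32.
#include <cassert>

int main()
{
    for(int lIdx = 0; lIdx < 32; ++lIdx)
    {
        for(int ele = 0; ele < 8; ++ele)
        {
            const int r = ele * 2 + (lIdx / 16); // output row: lanes 0-15 get even rows,
            assert(0 <= r && r < 16);            // lanes 16-31 get odd rows
            assert(ele * 8 / 8 == ele);          // acc_num == 8  -> slots 0, 1, ..., 7
            assert(ele * 16 / 8 == 2 * ele);     // acc_num == 16 -> slots 0, 2, ..., 14
        }
    }
    return 0;
}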
template <typename src_t, typename dst_t, typename acc_t, index_t acc_num>
__global__ void matmul_swizzle_a(const src_t* a, const src_t* b, dst_t* c)
{
    const int lIdx = threadIdx.x;

    using src_vec  = typename vector_type<src_t, 16>::type;
    src_vec a_frag = {};
    src_vec b_frag = {};

    using acc_vec = StaticBufferTupleOfVector<AddressSpaceEnum::Vgpr, acc_t, 1, acc_num, true>;
    acc_vec c_thread_buf_;

    const int lane = lIdx % 16;

    for(int ele = 0; ele < 16; ++ele)
    {
        b_frag[ele] = b[16 * lane + ele];
    }

    const int offset_m = (((lane & 1) << 3) | (lane >> 1));

    for(int ele = 0; ele < 16; ++ele)
    {
        a_frag[ele] = a[16 * offset_m + ele];
    }

    __syncthreads();

    builtin_wmma_naive_selector<src_vec, acc_vec>(a_frag, b_frag, c_thread_buf_);

    __syncthreads();

    static_for<0, 8, 1>{}([&](auto ele) {
        const int blk = lIdx / 16;
        const int r   = ele;
        c[16 * 8 * blk + 16 * r + lane] =
            ck::type_convert<dst_t>(c_thread_buf_[Number<ele * acc_num / 8>{}]);
    });
}
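[Annotation, not part of the commit] The only differences from matmul are the A-row swizzle offset_m = ((lane & 1) << 3) | (lane >> 1) and the matching output layout: instead of interleaving rows across the two half-waves, each 16-lane half writes a contiguous 8x16 block (rows 0-7, then rows 8-15). A quick host-side print of the lane-to-row permutation:

// Annotation sketch: the lane -> A-row permutation applied by matmul_swizzle_a.
// Lanes 0..15 read rows 0, 8, 1, 9, 2, 10, ..., 7, 15.
#include <cstdio>

int main()
{
    for(int lane = 0; lane < 16; ++lane)
    {
        const int offset_m = ((lane & 1) << 3) | (lane >> 1);
        std::printf("lane %2d reads A row %2d\n", lane, offset_m);
    }
    return 0;
}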
struct GemmParams
{
    GemmParams()
        : M(16), N(16), K(16), StrideA(16), StrideB(16), StrideC(16), alpha(1), beta(0)
    {
    }

    ck::index_t M;
    ck::index_t N;
    ck::index_t K;

    ck::index_t StrideA;
    ck::index_t StrideB;
    ck::index_t StrideC;

    float alpha;
    float beta;
};
template <typename GemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<BDataType>& B,
                 Tensor<CDataType>& C,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
{
    auto ref_gemm     = GemmInstance{};
    auto ref_invoker  = ref_gemm.MakeInvoker();
    auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}
template <typename KernelType, typename ADataType, typename BDataType, typename CDataType>
bool RunDeviceGEMM(KernelType kernel,
                   const Tensor<ADataType>& A,
                   const Tensor<BDataType>& B,
                   Tensor<CDataType>& C)
{
    DeviceMem a_m_k_device_buf(sizeof(ADataType) * A.mDesc.GetElementSpaceSize());
    DeviceMem b_n_k_device_buf(sizeof(BDataType) * B.mDesc.GetElementSpaceSize());
    DeviceMem c_m_n_device_buf(sizeof(CDataType) * C.mDesc.GetElementSpaceSize());

    a_m_k_device_buf.ToDevice(A.mData.data());
    b_n_k_device_buf.ToDevice(B.mData.data());

    kernel<<<1, 32>>>(static_cast<const ADataType*>(a_m_k_device_buf.GetDeviceBuffer()),
                      static_cast<const BDataType*>(b_n_k_device_buf.GetDeviceBuffer()),
                      static_cast<CDataType*>(c_m_n_device_buf.GetDeviceBuffer()));

    c_m_n_device_buf.FromDevice(C.mData.data());

    return true;
}
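[Annotation, not part of the commit] The <<<1, 32>>> launch above runs a single 32-lane workgroup, i.e. one wave32, which is all a lone 16x16x16 WMMA tile needs; there is no grid-level tiling in this smoke test. The per-lane work then has to cover the whole tile:

// Annotation sketch: one wave32 covers one 16x16 output tile, so each lane
// stores 8 elements, matching the 8-iteration static_for loops above.
#include <cassert>

int main()
{
    const int tile_m = 16, tile_n = 16, wave_size = 32;
    assert(tile_m * tile_n / wave_size == 8);
    return 0;
}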
template <typename DeviceWmma,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename GPUAccDataType,
          typename CPUAccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation,
          index_t CAccNum>
struct TestWmma
{
    auto PrepareGemmTensor(const ck::wmma_op_util::GemmParams& params)
    {
        auto f_host_tensor_descriptor =
            [](std::size_t row, std::size_t col, std::size_t stride, auto layout) {
                if(std::is_same<decltype(layout), ck::tensor_layout::gemm::RowMajor>::value)
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({stride, 1}));
                }
                else
                {
                    return HostTensorDescriptor(std::vector<std::size_t>({row, col}),
                                                std::vector<std::size_t>({1, stride}));
                }
            };

        Tensor<ADataType> a_m_k(
            f_host_tensor_descriptor(params.M, params.K, params.StrideA, ALayout{}));
        Tensor<BDataType> b_n_k(
            f_host_tensor_descriptor(params.K, params.N, params.StrideB, BLayout{}));

        Tensor<CDataType> c_m_n_host_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));
        Tensor<CDataType> c_m_n_device_result(
            f_host_tensor_descriptor(params.M, params.N, params.StrideC, CLayout{}));

        auto f_generate_tensor_value = [](auto& tensor, auto type) {
            using dataType = decltype(type);
            tensor.GenerateTensorValue(GeneratorTensor_2<dataType>{-5, 5});
        };

        f_generate_tensor_value(a_m_k, ADataType{});
        f_generate_tensor_value(b_n_k, BDataType{});

        return std::make_tuple(a_m_k, b_n_k, c_m_n_host_result, c_m_n_device_result);
    }
    auto operator()(const DeviceWmma& wmma_kernel)
    {
        std::cout << "ALayout = " << ALayout{}.name << ", BLayout = " << BLayout{}.name
                  << ", CLayout = " << CLayout{}.name << std::endl;

        // Arrange
        ck::wmma_op_util::GemmParams params;
        params.M       = 16;
        params.N       = 16;
        params.K       = 16;
        params.StrideA = 16;
        params.StrideB = 16;
        params.StrideC = 16;

        auto host_tensors = PrepareGemmTensor(params);

        const Tensor<ADataType>& a  = std::get<0>(host_tensors);
        const Tensor<BDataType>& b  = std::get<1>(host_tensors);
        Tensor<CDataType>& c_host   = std::get<2>(host_tensors);
        Tensor<CDataType>& c_device = std::get<3>(host_tensors);

        auto a_element_op = AElementwiseOperation{};
        auto b_element_op = BElementwiseOperation{};
        auto c_element_op = CElementwiseOperation{};

        using ReferenceGemmInstance =
            ck::tensor_operation::host::ReferenceGemm<ADataType,
                                                      BDataType,
                                                      CDataType,
                                                      CPUAccDataType,
                                                      AElementwiseOperation,
                                                      BElementwiseOperation,
                                                      CElementwiseOperation>;
        ck::wmma_op_util::RunHostGEMM<ReferenceGemmInstance>(
            a, b, c_host, a_element_op, b_element_op, c_element_op);

        // Act
        bool is_supported = ck::wmma_op_util::RunDeviceGEMM(wmma_kernel, a, b, c_device);

        if(is_supported)
        {
            // Assert
            bool res = false;

            if(std::is_same<CDataType, float>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, ck::half_t>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, ck::bhalf_t>::value)
            {
                // 0.5-pixel error tolerance is introduced by the accumulator difference:
                // the BF16 WMMA accumulator is BF16, while the host-side accumulator is float.
                res = ck::utils::check_err(
                    c_device.mData, c_host.mData, "Error: Incorrect results!", 0, 1.0);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, int8_t>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else if(std::is_same<CDataType, double>::value)
            {
                res = ck::utils::check_err(c_device.mData, c_host.mData);
                std::cout << (res ? "SUCCESS" : "FAILURE") << std::endl;
            }
            else
            {
                std::cout << "UNSUPPORTED CDataType" << std::endl;
            }

            return res;
        }
        else
        {
            return true;
        }
    }
};
} // namespace wmma_op_util
} // namespace ck
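[Annotation, not part of the commit] One more note on f_host_tensor_descriptor in PrepareGemmTensor above: a row-major descriptor uses strides {stride, 1}, so element (r, c) lives at offset r * stride + c, while a column-major descriptor uses {1, stride}, placing it at r + c * stride. A self-contained sketch of that offset arithmetic:

// Annotation sketch: offset arithmetic implied by the stride vectors chosen in
// f_host_tensor_descriptor for RowMajor vs ColumnMajor layouts.
#include <cassert>
#include <cstddef>

int main()
{
    const std::size_t stride = 16;
    auto row_major = [&](std::size_t r, std::size_t c) { return r * stride + c; };
    auto col_major = [&](std::size_t r, std::size_t c) { return r + c * stride; };
    assert(row_major(1, 0) == 16 && row_major(0, 1) == 1);
    assert(col_major(1, 0) == 1 && col_major(0, 1) == 16);
    return 0;
}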