Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ccf01887
"vscode:/vscode.git/clone" did not exist on "ef41b7f61cfcf6e0b121aaf124eb07433d5c748c"
Commit
ccf01887
authored
Apr 29, 2022
by
ltqin
Browse files
Merge branch 'develop' into add_mfma_f64
parents
04397fa0
95e93430
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
391 additions
and
20 deletions
+391
-20
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+45
-20
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
...on/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
+3
-0
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+343
-0
No files found.
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
ccf01887
...
...
@@ -16,6 +16,31 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2CTileMap Block2CTileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for \link
* DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the computing of
* pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2CTileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
@@ -25,7 +50,7 @@ template <typename GridwiseGemm,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Compute
BasePr
tOfBatch
,
typename
Compute
PtrOffse
tOfBatch
,
typename
Block2CTileMap
,
bool
HasMainKBlockLoop
>
__global__
void
...
...
@@ -43,7 +68,7 @@ __global__ void
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
const
Compute
BasePr
tOfBatch
compute_
base_ptr
_of_batch
_
,
const
Compute
PtrOffse
tOfBatch
compute_
ptr_offset
_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__))
...
...
@@ -52,11 +77,11 @@ __global__ void
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_
base_ptr
_of_batch
_
.
GetA
BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_
ptr_offset
_of_batch
.
GetA
PtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_
base_ptr
_of_batch
_
.
GetB
BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_
ptr_offset
_of_batch
.
GetB
PtrOffset
(
g_idx
)));
const
long_index_t
c_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_
base_ptr
_of_batch
_
.
GetC
BasePtr
(
g_idx
)));
static_cast
<
long_index_t
>
(
compute_
ptr_offset
_of_batch
.
GetC
PtrOffset
(
g_idx
)));
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
...
...
@@ -256,26 +281,26 @@ struct DeviceBatchedGemmXdl
return
globalblockid_to_m0_n0_block_cluster_adaptor
;
}
struct
Compute
BasePtr
OfStridedBatch
struct
Compute
PtrOffset
OfStridedBatch
{
Compute
BasePtr
OfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
)
Compute
PtrOffset
OfStridedBatch
(
index_t
BatchStrideA
,
index_t
BatchStrideB
,
index_t
BatchStrideC
)
:
BatchStrideA_
(
BatchStrideA
),
BatchStrideB_
(
BatchStrideB
),
BatchStrideC_
(
BatchStrideC
)
{
}
__host__
__device__
constexpr
long_index_t
GetA
BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetA
PtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideA_
);
}
__host__
__device__
constexpr
long_index_t
GetB
BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetB
PtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideB_
);
}
__host__
__device__
constexpr
long_index_t
GetC
BasePtr
(
index_t
g_idx
)
const
__host__
__device__
constexpr
long_index_t
GetC
PtrOffset
(
index_t
g_idx
)
const
{
return
g_idx
*
static_cast
<
long_index_t
>
(
BatchStrideC_
);
}
...
...
@@ -359,9 +384,9 @@ struct DeviceBatchedGemmXdl
DeviceBatchedGemmXdl
::
MakeBGridDescriptor_K0_N_K1
(
K
,
N
,
StrideB
)},
c_grid_desc_m_n_
{
DeviceBatchedGemmXdl
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
)},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
compute_
base_ptr
_of_batch_
{
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
(),
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
(),
c_grid_desc_m_n_
.
GetElementSpaceSize
()},
compute_
ptr_offset
_of_batch_
{
a_grid_desc_k0_m_k1_
.
GetElementSpaceSize
(),
b_grid_desc_k0_n_k1_
.
GetElementSpaceSize
(),
c_grid_desc_m_n_
.
GetElementSpaceSize
()},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -388,7 +413,7 @@ struct DeviceBatchedGemmXdl
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
Compute
BasePtr
OfStridedBatch
compute_
base_ptr
_of_batch_
;
Compute
PtrOffset
OfStridedBatch
compute_
ptr_offset
_of_batch_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -448,7 +473,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Compute
BasePtr
OfStridedBatch
,
Compute
PtrOffset
OfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
true
>
;
...
...
@@ -467,7 +492,7 @@ struct DeviceBatchedGemmXdl
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
compute_
base_ptr
_of_batch_
,
arg
.
compute_
ptr_offset
_of_batch_
,
arg
.
block_2_ctile_map_
);
}
else
...
...
@@ -482,7 +507,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
Compute
BasePtr
OfStridedBatch
,
Compute
PtrOffset
OfStridedBatch
,
remove_reference_t
<
Block2CTileMap
>
,
false
>
;
...
...
@@ -501,7 +526,7 @@ struct DeviceBatchedGemmXdl
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
arg
.
compute_
base_ptr
_of_batch_
,
arg
.
compute_
ptr_offset
_of_batch_
,
arg
.
block_2_ctile_map_
);
}
...
...
include/ck/tensor_operation/gpu/device/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp
View file @
ccf01887
...
...
@@ -18,6 +18,9 @@ namespace ck {
namespace
tensor_operation
{
namespace
device
{
/*
* \see \link device_batched_gemm_xdl.hpp kernel_batched_gemm_xdlops_v2r3() \endlink.
*/
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
...
...
test/gemm/gemm_util.hpp
View file @
ccf01887
<<<<<<<
HEAD
#ifndef GEMM_UTILS_HPP
#define GEMM_UTILS_HPP
...
...
@@ -347,3 +348,345 @@ struct TestGemmBF16
}
// namespace gemm_util
}
// namespace ck
#endif
=======
#ifndef GEMM_UTILS_HPP
#define GEMM_UTILS_HPP
#include "check_err.hpp"
#include "config.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "reference_gemm.hpp"
#include "tensor_layout.hpp"
namespace
ck
{
namespace
gemm_util
{
struct
GemmParams
{
GemmParams
()
:
M
(
1024
),
N
(
1024
),
K
(
1024
),
StrideA
(
1024
),
StrideB
(
1024
),
StrideC
(
1024
),
alpha
(
1
),
beta
(
0
)
{
}
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
StrideA
;
ck
::
index_t
StrideB
;
ck
::
index_t
StrideC
;
float
alpha
;
float
beta
;
};
template
<
typename
GemmInstance
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
RunHostGEMM
(
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
BDataType
>&
B
,
Tensor
<
CDataType
>&
C
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
auto
ref_gemm
=
GemmInstance
{};
auto
ref_invoker
=
ref_gemm
.
MakeInvoker
();
auto
ref_argument
=
ref_gemm
.
MakeArgument
(
A
,
B
,
C
,
a_element_op
,
b_element_op
,
c_element_op
);
ref_invoker
.
Run
(
ref_argument
);
}
template
<
typename
DeviceGemmPtr_
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
void
RunDeviceGEMM
(
DeviceGemmPtr_
&
gemmPtr
,
const
ck
::
gemm_util
::
GemmParams
&
params
,
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
BDataType
>&
B
,
Tensor
<
CDataType
>&
C
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
{
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
auto
invoker_ptr
=
gemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
gemmPtr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
params
.
M
,
params
.
N
,
params
.
K
,
params
.
StrideA
,
params
.
StrideB
,
params
.
StrideC
,
a_element_op
,
b_element_op
,
c_element_op
);
if
(
!
gemmPtr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
}
invoker_ptr
->
Run
(
argument_ptr
.
get
());
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
}
template
<
typename
DeviceGemmPtr_
,
typename
ADataType
,
typename
BDataType
,
typename
CDataType
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
TestGemm
{
auto
PrepareGemmTensor
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
{
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
auto
f_generate_tensor_value
=
[](
auto
&
desc
,
auto
type
)
{
using
dataType
=
decltype
(
type
);
if
(
std
::
is_same
<
dataType
,
int8_t
>::
value
)
{
desc
.
GenerateTensorValue
(
GeneratorTensor_2
<
int8_t
>
{
-
5
,
5
});
}
else
{
desc
.
GenerateTensorValue
(
GeneratorTensor_3
<
dataType
>
{
-
0.5
,
0.5
});
}
};
f_generate_tensor_value
(
a_m_k
,
ADataType
{});
f_generate_tensor_value
(
b_k_n
,
BDataType
{});
return
std
::
make_tuple
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
c_m_n_device_result
);
}
auto
operator
()(
DeviceGemmPtr_
&
gemmPtr
)
{
std
::
cout
<<
"ALayout = "
<<
ALayout
{}.
name
<<
", BLayout = "
<<
BLayout
{}.
name
<<
", CLayout = "
<<
CLayout
{}.
name
<<
std
::
endl
;
std
::
cout
<<
gemmPtr
->
GetTypeString
()
<<
std
::
endl
;
// Arrange
ck
::
gemm_util
::
GemmParams
params
;
params
.
M
=
1024
;
params
.
N
=
1024
;
params
.
K
=
1024
;
params
.
StrideA
=
1024
;
params
.
StrideB
=
1024
;
params
.
StrideC
=
1024
;
auto
host_tensors
=
PrepareGemmTensor
(
params
);
const
Tensor
<
ADataType
>&
a
=
std
::
get
<
0
>
(
host_tensors
);
const
Tensor
<
BDataType
>&
b
=
std
::
get
<
1
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_host
=
std
::
get
<
2
>
(
host_tensors
);
Tensor
<
CDataType
>&
c_device
=
std
::
get
<
3
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
c_element_op
=
CElementwiseOperation
{};
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
BDataType
,
CDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
// Act
ck
::
gemm_util
::
RunDeviceGEMM
(
gemmPtr
,
params
,
a
,
b
,
c_device
,
a_element_op
,
b_element_op
,
c_element_op
);
// Assert
bool
res
=
false
;
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
return
res
;
}
};
template
<
typename
DeviceGemmPtr_
,
typename
ALayout
,
typename
BLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
TestGemmBF16
{
using
BF16
=
ck
::
bhalf_t
;
auto
PrepareGemmTensorBF16
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
{
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
else
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
};
// use fp32 host kernel to verify bf16 device kernel
Tensor
<
BF16
>
a_m_k_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
Tensor
<
BF16
>
b_k_n_bf16
(
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
BF16
>
c_m_n_device_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
a_m_k_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
Tensor
<
float
>
b_k_n_fp32
(
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
float
>
c_m_n_host_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_device_fp32
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
a_m_k_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
b_k_n_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
bf16_to_f32_
(
a_m_k_bf16
,
a_m_k_fp32
);
bf16_to_f32_
(
b_k_n_bf16
,
b_k_n_fp32
);
return
std
::
make_tuple
(
a_m_k_bf16
,
b_k_n_bf16
,
c_m_n_device_bf16
,
a_m_k_fp32
,
b_k_n_fp32
,
c_m_n_host_fp32
,
c_m_n_device_fp32
);
}
auto
operator
()(
DeviceGemmPtr_
&
gemmPtr
)
{
// Arrange
ck
::
gemm_util
::
GemmParams
params
;
params
.
M
=
1024
;
params
.
N
=
1024
;
params
.
K
=
1024
;
params
.
StrideA
=
1024
;
params
.
StrideB
=
1024
;
params
.
StrideC
=
1024
;
auto
host_tensors
=
PrepareGemmTensorBF16
(
params
);
const
Tensor
<
BF16
>&
a_bf16
=
std
::
get
<
0
>
(
host_tensors
);
const
Tensor
<
BF16
>&
b_bf16
=
std
::
get
<
1
>
(
host_tensors
);
Tensor
<
BF16
>&
c_device_bf16
=
std
::
get
<
2
>
(
host_tensors
);
Tensor
<
float
>&
a_fp32
=
std
::
get
<
3
>
(
host_tensors
);
Tensor
<
float
>&
b_fp32
=
std
::
get
<
4
>
(
host_tensors
);
Tensor
<
float
>&
c_host_fp32
=
std
::
get
<
5
>
(
host_tensors
);
Tensor
<
float
>&
c_device_fp32
=
std
::
get
<
6
>
(
host_tensors
);
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
c_element_op
=
CElementwiseOperation
{};
// use fp32 host kernel to verify bf16 device kernel
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
float
,
float
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
;
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
a_fp32
,
b_fp32
,
c_host_fp32
,
a_element_op
,
b_element_op
,
c_element_op
);
// Act
ck
::
gemm_util
::
RunDeviceGEMM
(
gemmPtr
,
params
,
a_bf16
,
b_bf16
,
c_device_bf16
,
a_element_op
,
b_element_op
,
c_element_op
);
bf16_to_f32_
(
c_device_bf16
,
c_device_fp32
);
// Assert
bool
res
=
ck
::
utils
::
check_err
(
c_device_fp32
.
mData
,
c_host_fp32
.
mData
,
"Error: incorrect results!"
,
1e-2
f
,
1e-3
f
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
;
};
};
}
// namespace gemm_util
}
// namespace ck
#endif
>>>>>>>
develop
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment