Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
ebfa3921
Commit
ebfa3921
authored
Apr 30, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/fix_test' into add_mfma_f64
parents
58f4d821
579e8e76
Changes
21
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
333 additions
and
349 deletions
+333
-349
test/gemm/gemm_util.hpp
test/gemm/gemm_util.hpp
+333
-349
No files found.
test/gemm/gemm_util.hpp
View file @
ebfa3921
#ifndef GEMM_UTILS_HPP
#ifndef GEMM_UTILS_HPP
#define GEMM_UTILS_HPP
#define GEMM_UTILS_HPP
#include "check_err.hpp"
#include "check_err.hpp"
#include "config.hpp"
#include "config.hpp"
#include "device.hpp"
#include "device.hpp"
#include "host_tensor.hpp"
#include "host_tensor.hpp"
#include "host_tensor_generator.hpp"
#include "host_tensor_generator.hpp"
#include "reference_gemm.hpp"
#include "reference_gemm.hpp"
#include "tensor_layout.hpp"
#include "tensor_layout.hpp"
namespace
ck
{
namespace
ck
{
namespace
gemm_util
{
namespace
gemm_util
{
struct
GemmParams
struct
GemmParams
{
{
GemmParams
()
GemmParams
()
:
M
(
1024
),
N
(
1024
),
K
(
1024
),
StrideA
(
1024
),
StrideB
(
1024
),
StrideC
(
1024
),
alpha
(
1
),
beta
(
0
)
:
M
(
1024
),
N
(
1024
),
K
(
1024
),
StrideA
(
1024
),
StrideB
(
1024
),
StrideC
(
1024
),
alpha
(
1
),
beta
(
0
)
{
{
}
}
ck
::
index_t
M
;
ck
::
index_t
M
;
ck
::
index_t
N
;
ck
::
index_t
N
;
ck
::
index_t
K
;
ck
::
index_t
K
;
ck
::
index_t
StrideA
;
ck
::
index_t
StrideA
;
ck
::
index_t
StrideB
;
ck
::
index_t
StrideB
;
ck
::
index_t
StrideC
;
ck
::
index_t
StrideC
;
float
alpha
;
float
alpha
;
float
beta
;
float
beta
;
};
};
// Run the reference (host/CPU) GEMM implementation `GemmInstance` on the given
// host tensors. Results are written into C in place.
//
// `GemmInstance` is expected to follow the ReferenceGemm interface:
// MakeInvoker() and MakeArgument(A, B, C, a_op, b_op, c_op).
template <typename GemmInstance,
          typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
void RunHostGEMM(const Tensor<ADataType>& A,
                 const Tensor<BDataType>& B,
                 Tensor<CDataType>& C,
                 AElementwiseOperation a_element_op,
                 BElementwiseOperation b_element_op,
                 CElementwiseOperation c_element_op)
{
    auto ref_gemm    = GemmInstance{};
    auto ref_invoker = ref_gemm.MakeInvoker();

    auto ref_argument = ref_gemm.MakeArgument(A, B, C, a_element_op, b_element_op, c_element_op);

    ref_invoker.Run(ref_argument);
}
template
<
typename
DeviceGemmPtr_
,
template
<
typename
DeviceGemmPtr_
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
void
RunDeviceGEMM
(
DeviceGemmPtr_
&
gemmPtr
,
void
RunDeviceGEMM
(
DeviceGemmPtr_
&
gemmPtr
,
const
ck
::
gemm_util
::
GemmParams
&
params
,
const
ck
::
gemm_util
::
GemmParams
&
params
,
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
ADataType
>&
A
,
const
Tensor
<
BDataType
>&
B
,
const
Tensor
<
BDataType
>&
B
,
Tensor
<
CDataType
>&
C
,
Tensor
<
CDataType
>&
C
,
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
CElementwiseOperation
c_element_op
)
{
{
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
());
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
A
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
B
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
C
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
a_m_k_device_buf
.
ToDevice
(
A
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
B
.
mData
.
data
());
auto
invoker_ptr
=
gemmPtr
->
MakeInvokerPointer
();
auto
invoker_ptr
=
gemmPtr
->
MakeInvokerPointer
();
auto
argument_ptr
=
auto
argument_ptr
=
gemmPtr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
gemmPtr
->
MakeArgumentPointer
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
params
.
M
,
params
.
M
,
params
.
N
,
params
.
N
,
params
.
K
,
params
.
K
,
params
.
StrideA
,
params
.
StrideA
,
params
.
StrideB
,
params
.
StrideB
,
params
.
StrideC
,
params
.
StrideC
,
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
);
c_element_op
);
if
(
!
gemmPtr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
if
(
!
gemmPtr
->
IsSupportedArgument
(
argument_ptr
.
get
()))
{
{
throw
std
::
runtime_error
(
throw
std
::
runtime_error
(
"wrong! device_gemm with the specified compilation parameters does "
"wrong! device_gemm with the specified compilation parameters does "
"not support this GEMM problem"
);
"not support this GEMM problem"
);
}
}
invoker_ptr
->
Run
(
argument_ptr
.
get
());
invoker_ptr
->
Run
(
argument_ptr
.
get
());
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
c_m_n_device_buf
.
FromDevice
(
C
.
mData
.
data
());
}
}
template
<
typename
DeviceGemmPtr_
,
template
<
typename
DeviceGemmPtr_
,
typename
ADataType
,
typename
ADataType
,
typename
BDataType
,
typename
BDataType
,
typename
CDataType
,
typename
CDataType
,
typename
AccDataType
,
typename
ALayout
,
typename
ALayout
,
typename
BLayout
,
typename
BLayout
,
typename
CLayout
,
typename
CLayout
,
typename
AElementwiseOperation
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
typename
CElementwiseOperation
>
struct
TestGemm
struct
TestGemm
{
{
auto
PrepareGemmTensor
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
auto
PrepareGemmTensor
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
{
{
auto
f_host_tensor_descriptor
=
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
{
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
}
}
else
else
{
{
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
}
}
};
};
Tensor
<
ADataType
>
a_m_k
(
Tensor
<
ADataType
>
a_m_k
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
Tensor
<
BDataType
>
b_k_n
(
Tensor
<
BDataType
>
b_k_n
(
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
CDataType
>
c_m_n_host_result
(
Tensor
<
CDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
CDataType
>
c_m_n_device_result
(
Tensor
<
CDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
auto
f_generate_tensor_value
=
[](
auto
&
tensor
,
auto
type
)
{
auto
f_generate_tensor_value
=
[](
auto
&
desc
,
auto
type
)
{
using
dataType
=
decltype
(
type
);
using
dataType
=
decltype
(
type
);
tensor
.
GenerateTensorValue
(
GeneratorTensor_2
<
dataType
>
{
-
5
,
5
});
if
(
std
::
is_same
<
dataType
,
int8_t
>::
value
||
std
::
is_same
<
dataType
,
double
>::
value
)
};
{
desc
.
GenerateTensorValue
(
GeneratorTensor_2
<
dataType
>
{
-
5
,
5
});
f_generate_tensor_value
(
a_m_k
,
ADataType
{});
}
f_generate_tensor_value
(
b_k_n
,
BDataType
{});
else
{
return
std
::
make_tuple
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
c_m_n_device_result
);
desc
.
GenerateTensorValue
(
GeneratorTensor_3
<
dataType
>
{
-
0.5
,
0.5
});
}
}
};
auto
operator
()(
DeviceGemmPtr_
&
gemmPtr
)
{
f_generate_tensor_value
(
a_m_k
,
ADataType
{});
std
::
cout
<<
"ALayout = "
<<
ALayout
{}.
name
<<
", BLayout = "
<<
BLayout
{}.
name
f_generate_tensor_value
(
b_k_n
,
BDataType
{});
<<
", CLayout = "
<<
CLayout
{}.
name
<<
std
::
endl
;
std
::
cout
<<
gemmPtr
->
GetTypeString
()
<<
std
::
endl
;
return
std
::
make_tuple
(
a_m_k
,
b_k_n
,
c_m_n_host_result
,
c_m_n_device_result
);
}
// Arrange
ck
::
gemm_util
::
GemmParams
params
;
auto
operator
()(
DeviceGemmPtr_
&
gemmPtr
)
params
.
M
=
1024
;
{
params
.
N
=
1024
;
std
::
cout
<<
"data type: "
<<
typeid
(
ADataType
{}).
name
()
<<
std
::
endl
;
params
.
K
=
1024
;
std
::
cout
<<
"ALayout = "
<<
ALayout
{}.
name
<<
", BLayout = "
<<
BLayout
{}.
name
params
.
StrideA
=
1024
;
<<
", CLayout = "
<<
CLayout
{}.
name
<<
std
::
endl
;
params
.
StrideB
=
1024
;
std
::
cout
<<
gemmPtr
->
GetTypeString
()
<<
std
::
endl
;
params
.
StrideC
=
1024
;
// Arrange
auto
host_tensors
=
PrepareGemmTensor
(
params
);
ck
::
gemm_util
::
GemmParams
params
;
params
.
M
=
1024
;
const
Tensor
<
ADataType
>&
a
=
std
::
get
<
0
>
(
host_tensors
);
params
.
N
=
1024
;
const
Tensor
<
BDataType
>&
b
=
std
::
get
<
1
>
(
host_tensors
);
params
.
K
=
1024
;
Tensor
<
CDataType
>&
c_host
=
std
::
get
<
2
>
(
host_tensors
);
params
.
StrideA
=
1024
;
Tensor
<
CDataType
>&
c_device
=
std
::
get
<
3
>
(
host_tensors
);
params
.
StrideB
=
1024
;
params
.
StrideC
=
1024
;
auto
a_element_op
=
AElementwiseOperation
{};
auto
b_element_op
=
BElementwiseOperation
{};
auto
host_tensors
=
PrepareGemmTensor
(
params
);
auto
c_element_op
=
CElementwiseOperation
{};
const
Tensor
<
ADataType
>&
a
=
std
::
get
<
0
>
(
host_tensors
);
using
ReferenceGemmInstance
=
const
Tensor
<
BDataType
>&
b
=
std
::
get
<
1
>
(
host_tensors
);
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
Tensor
<
CDataType
>&
c_host
=
std
::
get
<
2
>
(
host_tensors
);
BDataType
,
Tensor
<
CDataType
>&
c_device
=
std
::
get
<
3
>
(
host_tensors
);
CDataType
,
AElementwiseOperation
,
auto
a_element_op
=
AElementwiseOperation
{};
BElementwiseOperation
,
auto
b_element_op
=
BElementwiseOperation
{};
CElementwiseOperation
>
;
auto
c_element_op
=
CElementwiseOperation
{};
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
using
ReferenceGemmInstance
=
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
ADataType
,
// Act
BDataType
,
ck
::
gemm_util
::
RunDeviceGEMM
(
CDataType
,
gemmPtr
,
params
,
a
,
b
,
c_device
,
a_element_op
,
b_element_op
,
c_element_op
);
AccDataType
,
AElementwiseOperation
,
// Assert
BElementwiseOperation
,
bool
res
=
false
;
CElementwiseOperation
>
;
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
{
a
,
b
,
c_host
,
a_element_op
,
b_element_op
,
c_element_op
);
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
// Act
}
ck
::
gemm_util
::
RunDeviceGEMM
(
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
gemmPtr
,
params
,
a
,
b
,
c_device
,
a_element_op
,
b_element_op
,
c_element_op
);
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
// Assert
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
bool
res
=
false
;
}
if
(
std
::
is_same
<
CDataType
,
double
>::
value
)
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
{
{
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
}
}
else
if
(
std
::
is_same
<
CDataType
,
float
>::
value
)
{
return
res
;
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
}
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
};
}
else
if
(
std
::
is_same
<
CDataType
,
ck
::
half_t
>::
value
)
template
<
typename
DeviceGemmPtr_
,
{
typename
ALayout
,
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
typename
BLayout
,
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
typename
CLayout
,
}
typename
AElementwiseOperation
,
else
if
(
std
::
is_same
<
CDataType
,
int8_t
>::
value
)
typename
BElementwiseOperation
,
{
typename
CElementwiseOperation
>
res
=
ck
::
utils
::
check_err
(
c_device
.
mData
,
c_host
.
mData
);
struct
TestGemmBF16
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
{
}
using
BF16
=
ck
::
bhalf_t
;
return
res
;
auto
PrepareGemmTensorBF16
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
}
{
};
auto
f_host_tensor_descriptor
=
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
template
<
typename
DeviceGemmPtr_
,
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
typename
ALayout
,
{
typename
BLayout
,
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
typename
CLayout
,
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
typename
AElementwiseOperation
,
}
typename
BElementwiseOperation
,
else
typename
CElementwiseOperation
>
{
struct
TestGemmBF16
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
{
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
using
BF16
=
ck
::
bhalf_t
;
}
};
auto
PrepareGemmTensorBF16
(
const
ck
::
gemm_util
::
GemmParams
&
params
)
{
// use fp32 host kernel to verify bf16 device kernel
auto
f_host_tensor_descriptor
=
Tensor
<
BF16
>
a_m_k_bf16
(
[](
std
::
size_t
row
,
std
::
size_t
col
,
std
::
size_t
stride
,
auto
layout
)
{
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
if
(
std
::
is_same
<
decltype
(
layout
),
ck
::
tensor_layout
::
gemm
::
RowMajor
>::
value
)
Tensor
<
BF16
>
b_k_n_bf16
(
{
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
Tensor
<
BF16
>
c_m_n_device_bf16
(
std
::
vector
<
std
::
size_t
>
({
stride
,
1
}));
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
}
else
Tensor
<
float
>
a_m_k_fp32
(
{
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
return
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
row
,
col
}),
Tensor
<
float
>
b_k_n_fp32
(
std
::
vector
<
std
::
size_t
>
({
1
,
stride
}));
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
}
Tensor
<
float
>
c_m_n_host_fp32
(
};
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
float
>
c_m_n_device_fp32
(
// use fp32 host kernel to verify bf16 device kernel
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
Tensor
<
BF16
>
a_m_k_bf16
(
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
a_m_k_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
Tensor
<
BF16
>
b_k_n_bf16
(
b_k_n_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
Tensor
<
BF16
>
c_m_n_device_bf16
(
bf16_to_f32_
(
a_m_k_bf16
,
a_m_k_fp32
);
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
bf16_to_f32_
(
b_k_n_bf16
,
b_k_n_fp32
);
Tensor
<
float
>
a_m_k_fp32
(
return
std
::
make_tuple
(
a_m_k_bf16
,
f_host_tensor_descriptor
(
params
.
M
,
params
.
K
,
params
.
StrideA
,
ALayout
{}));
b_k_n_bf16
,
Tensor
<
float
>
b_k_n_fp32
(
c_m_n_device_bf16
,
f_host_tensor_descriptor
(
params
.
K
,
params
.
N
,
params
.
StrideB
,
BLayout
{}));
a_m_k_fp32
,
Tensor
<
float
>
c_m_n_host_fp32
(
b_k_n_fp32
,
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
c_m_n_host_fp32
,
Tensor
<
float
>
c_m_n_device_fp32
(
c_m_n_device_fp32
);
f_host_tensor_descriptor
(
params
.
M
,
params
.
N
,
params
.
StrideC
,
CLayout
{}));
}
a_m_k_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
auto
operator
()(
DeviceGemmPtr_
&
gemmPtr
)
b_k_n_bf16
.
GenerateTensorValue
(
GeneratorTensor_3
<
BF16
>
{
-
0.5
,
0.5
});
{
// Arrange
bf16_to_f32_
(
a_m_k_bf16
,
a_m_k_fp32
);
ck
::
gemm_util
::
GemmParams
params
;
bf16_to_f32_
(
b_k_n_bf16
,
b_k_n_fp32
);
params
.
M
=
1024
;
params
.
N
=
1024
;
return
std
::
make_tuple
(
a_m_k_bf16
,
params
.
K
=
1024
;
b_k_n_bf16
,
params
.
StrideA
=
1024
;
c_m_n_device_bf16
,
params
.
StrideB
=
1024
;
a_m_k_fp32
,
params
.
StrideC
=
1024
;
b_k_n_fp32
,
c_m_n_host_fp32
,
auto
host_tensors
=
PrepareGemmTensorBF16
(
params
);
c_m_n_device_fp32
);
const
Tensor
<
BF16
>&
a_bf16
=
std
::
get
<
0
>
(
host_tensors
);
}
const
Tensor
<
BF16
>&
b_bf16
=
std
::
get
<
1
>
(
host_tensors
);
Tensor
<
BF16
>&
c_device_bf16
=
std
::
get
<
2
>
(
host_tensors
);
auto
operator
()(
DeviceGemmPtr_
&
gemmPtr
)
Tensor
<
float
>&
a_fp32
=
std
::
get
<
3
>
(
host_tensors
);
{
Tensor
<
float
>&
b_fp32
=
std
::
get
<
4
>
(
host_tensors
);
// Arrange
Tensor
<
float
>&
c_host_fp32
=
std
::
get
<
5
>
(
host_tensors
);
ck
::
gemm_util
::
GemmParams
params
;
Tensor
<
float
>&
c_device_fp32
=
std
::
get
<
6
>
(
host_tensors
);
params
.
M
=
1024
;
params
.
N
=
1024
;
auto
a_element_op
=
AElementwiseOperation
{};
params
.
K
=
1024
;
auto
b_element_op
=
BElementwiseOperation
{};
params
.
StrideA
=
1024
;
auto
c_element_op
=
CElementwiseOperation
{};
params
.
StrideB
=
1024
;
params
.
StrideC
=
1024
;
// use fp32 host kernel to verify bf16 device kernel
using
ReferenceGemmInstance
=
auto
host_tensors
=
PrepareGemmTensorBF16
(
params
);
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
const
Tensor
<
BF16
>&
a_bf16
=
std
::
get
<
0
>
(
host_tensors
);
float
,
const
Tensor
<
BF16
>&
b_bf16
=
std
::
get
<
1
>
(
host_tensors
);
float
,
Tensor
<
BF16
>&
c_device_bf16
=
std
::
get
<
2
>
(
host_tensors
);
AElementwiseOperation
,
Tensor
<
float
>&
a_fp32
=
std
::
get
<
3
>
(
host_tensors
);
BElementwiseOperation
,
Tensor
<
float
>&
b_fp32
=
std
::
get
<
4
>
(
host_tensors
);
CElementwiseOperation
>
;
Tensor
<
float
>&
c_host_fp32
=
std
::
get
<
5
>
(
host_tensors
);
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
Tensor
<
float
>&
c_device_fp32
=
std
::
get
<
6
>
(
host_tensors
);
a_fp32
,
b_fp32
,
c_host_fp32
,
a_element_op
,
b_element_op
,
c_element_op
);
auto
a_element_op
=
AElementwiseOperation
{};
// Act
auto
b_element_op
=
BElementwiseOperation
{};
ck
::
gemm_util
::
RunDeviceGEMM
(
gemmPtr
,
auto
c_element_op
=
CElementwiseOperation
{};
params
,
a_bf16
,
// use fp32 host kernel to verify bf16 device kernel
b_bf16
,
using
ReferenceGemmInstance
=
c_device_bf16
,
ck
::
tensor_operation
::
host
::
ReferenceGemm
<
float
,
a_element_op
,
float
,
b_element_op
,
float
,
c_element_op
);
float
,
AElementwiseOperation
,
bf16_to_f32_
(
c_device_bf16
,
c_device_fp32
);
BElementwiseOperation
,
CElementwiseOperation
>
;
// Assert
ck
::
gemm_util
::
RunHostGEMM
<
ReferenceGemmInstance
>
(
bool
res
=
ck
::
utils
::
check_err
(
a_fp32
,
b_fp32
,
c_host_fp32
,
a_element_op
,
b_element_op
,
c_element_op
);
c_device_fp32
.
mData
,
c_host_fp32
.
mData
,
"Error: incorrect results!"
,
1e-2
f
,
1e-3
f
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
// Act
ck
::
gemm_util
::
RunDeviceGEMM
(
gemmPtr
,
return
res
;
params
,
};
a_bf16
,
};
b_bf16
,
c_device_bf16
,
}
// namespace gemm_util
a_element_op
,
}
// namespace ck
b_element_op
,
#endif
c_element_op
);
bf16_to_f32_
(
c_device_bf16
,
c_device_fp32
);
// Assert
bool
res
=
ck
::
utils
::
check_err
(
c_device_fp32
.
mData
,
c_host_fp32
.
mData
,
"Error: incorrect results!"
,
1e-2
f
,
1e-3
f
);
std
::
cout
<<
(
res
?
"SUCCESS"
:
"FAILURE"
)
<<
std
::
endl
;
return
res
;
};
};
}
// namespace gemm_util
}
// namespace ck
#endif
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment