Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
165e30cd
Commit
165e30cd
authored
Nov 21, 2021
by
Chao Liu
Browse files
adding bias add
parent
d5679ea6
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
128 additions
and
105 deletions
+128
-105
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp
...el/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp
+32
-62
example/2_gemm_xdl_bias_add/gemm_xdl_bias_add.cpp
example/2_gemm_xdl_bias_add/gemm_xdl_bias_add.cpp
+23
-1
example/2_gemm_xdl_bias_add/include/device_gemm_bias_add.hpp
example/2_gemm_xdl_bias_add/include/device_gemm_bias_add.hpp
+0
-42
example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp
.../2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp
+73
-0
No files found.
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_v2r5.hpp
View file @
165e30cd
...
...
@@ -12,13 +12,14 @@
namespace
ck
{
#if CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VALUE
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -32,9 +33,13 @@ __global__ void
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC
*
__restrict__
p_c0_grid
,
const
FloatC
*
__restrict__
p_c1_grid
,
const
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CElementwiseOperation
c_element_op
,
...
...
@@ -48,75 +53,19 @@ __global__ void
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_c0_grid
,
p_c1_grid
,
p_shared_block
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
);
}
#elif CK_EXPERIMENTAL_PASS_TENSOR_DESCRIPTOR_BY_VOID_POINTER
template
<
typename
GridwiseGemm
,
typename
FloatAB
,
typename
FloatC
,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
typename
Block2CTileMap
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_gemm_xdlops_v2r5
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
void
CONSTANT
*
p_a_grid_desc_k0_m_k1
,
const
void
CONSTANT
*
p_b_grid_desc_k0_n_k1
,
const
void
CONSTANT
*
p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
void
CONSTANT
*
p_a_element_op
,
const
void
CONSTANT
*
p_b_element_op
,
const
void
CONSTANT
*
p_c_element_op
,
const
void
CONSTANT
*
p_block_2_ctile_map
)
{
constexpr
index_t
shared_block_size
=
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()
/
sizeof
(
FloatAB
);
const
auto
a_grid_desc_k0_m_k1
=
*
reinterpret_cast
<
const
AGridDesc_K0_M_K1
*>
(
cast_pointer_to_generic_address_space
(
p_a_grid_desc_k0_m_k1
));
const
auto
b_grid_desc_k0_n_k1
=
*
reinterpret_cast
<
const
BGridDesc_K0_N_K1
*>
(
cast_pointer_to_generic_address_space
(
p_b_grid_desc_k0_n_k1
));
const
auto
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
=
*
reinterpret_cast
<
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
*>
(
cast_pointer_to_generic_address_space
(
p_c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
));
const
auto
block_2_ctile_map
=
*
reinterpret_cast
<
const
Block2CTileMap
*>
(
cast_pointer_to_generic_address_space
(
p_block_2_ctile_map
));
const
auto
a_element_op
=
*
reinterpret_cast
<
const
AElementwiseOperation
*>
(
cast_pointer_to_generic_address_space
(
p_a_element_op
));
const
auto
b_element_op
=
*
reinterpret_cast
<
const
BElementwiseOperation
*>
(
cast_pointer_to_generic_address_space
(
p_b_element_op
));
const
auto
c_element_op
=
*
reinterpret_cast
<
const
CElementwiseOperation
*>
(
cast_pointer_to_generic_address_space
(
p_c_element_op
));
__shared__
FloatAB
p_shared_block
[
shared_block_size
];
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_a_grid
,
p_b_grid
,
p_c_grid
,
p_shared_block
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
a_element_op
,
b_element_op
,
c_element_op
,
block_2_ctile_map
);
}
#endif
template
<
index_t
BlockSize
,
typename
FloatAB
,
...
...
@@ -126,6 +75,8 @@ template <index_t BlockSize,
typename
AGridDesc_K0_M_K1
,
typename
BGridDesc_K0_N_K1
,
typename
CGridDesc_M_N
,
typename
C0GridDesc_M_N
,
typename
C1GridDesc_M_N
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
,
...
...
@@ -281,8 +232,10 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
return
has_main_k0_block_loop
;
}
// TODO fix this
template
<
typename
CGridDesc_M_N_any
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
const
CGridDesc_M_N
_any
&
c_grid_desc_m_n
)
{
constexpr
auto
max_lds_align
=
K1
;
...
...
@@ -369,6 +322,13 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
using
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C1GridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
template
<
bool
HasMainKBlockLoop
>
...
...
@@ -376,10 +336,14 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
Run
(
const
FloatAB
*
__restrict__
p_a_grid
,
const
FloatAB
*
__restrict__
p_b_grid
,
FloatC
*
__restrict__
p_c_grid
,
const
FloatC
*
__restrict__
p_c0_grid
,
const
FloatC
*
__restrict__
p_c1_grid
,
FloatAB
*
__restrict__
p_shared_block
,
const
AGridDesc_K0_M_K1
&
a_grid_desc_k0_m_k1
,
const
BGridDesc_K0_N_K1
&
b_grid_desc_k0_n_k1
,
const
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
&
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CElementwiseOperation
&
c_element_op
,
...
...
@@ -392,6 +356,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r5
auto
c_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c_grid
,
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
auto
c0_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c0_grid
,
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
auto
c1_grid_buf
=
make_dynamic_buffer
<
AddressSpaceEnum_t
::
Global
>
(
p_c1_grid
,
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2
.
GetElementSpaceSize
());
const
auto
K0
=
a_grid_desc_k0_m_k1
.
GetLength
(
I0
);
// divide block work by [M, N]
...
...
example/2_gemm_xdl_bias_add/gemm_xdl_bias_add.cpp
View file @
165e30cd
...
...
@@ -14,7 +14,9 @@
#include "device_base.hpp"
#include "example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp"
// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * D[m] + gamma * E[m, n]
// C[m, n] = alpha(A[m, k] * B[k, n]) + beta * C0[m, n] + gamma * C1[m]
// assume C0 has same layout as C
// assume C1 is contiguous in memory
struct
PassThrough
{
...
...
@@ -175,9 +177,19 @@ int main(int argc, char* argv[])
Tensor
<
BDataType
>
c_m_n_host_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
Tensor
<
BDataType
>
c_m_n_device_result
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
// C0[m ,n]
Tensor
<
BDataType
>
c0_m_n
(
f_host_tensor_descriptor
(
M
,
N
,
StrideC
,
CLayout
{}));
// C1[m]
Tensor
<
CDataType
>
c1_m_n
(
HostTensorDescriptor
(
std
::
vector
<
std
::
size_t
>
({
static_cast
<
std
::
size_t
>
(
M
),
static_cast
<
std
::
size_t
>
(
N
)}),
std
::
vector
<
std
::
size_t
>
({
1
,
0
})));
std
::
cout
<<
"a_m_k: "
<<
a_m_k
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"b_k_n: "
<<
b_k_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c_m_n: "
<<
c_m_n_host_result
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c0_m_n: "
<<
c0_m_n
.
mDesc
<<
std
::
endl
;
std
::
cout
<<
"c1_m_n: "
<<
c1_m_n
.
mDesc
<<
std
::
endl
;
switch
(
init_method
)
{
...
...
@@ -185,19 +197,27 @@ int main(int argc, char* argv[])
case
1
:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_2
<
ADataType
>
{
-
5
,
5
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
BDataType
>
{
-
5
,
5
});
c0_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
CDataType
>
{
-
5
,
5
});
c1_m_n
.
GenerateTensorValue
(
GeneratorTensor_2
<
CDataType
>
{
-
5
,
5
});
break
;
default:
a_m_k
.
GenerateTensorValue
(
GeneratorTensor_3
<
ADataType
>
{
0.0
,
1.0
});
b_k_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
BDataType
>
{
-
0.5
,
0.5
});
c0_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
CDataType
>
{
0.0
,
1.0
});
c1_m_n
.
GenerateTensorValue
(
GeneratorTensor_3
<
CDataType
>
{
0.0
,
1.0
});
}
DeviceMem
a_m_k_device_buf
(
sizeof
(
ADataType
)
*
a_m_k
.
mDesc
.
GetElementSpace
());
DeviceMem
b_k_n_device_buf
(
sizeof
(
BDataType
)
*
b_k_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c_m_n_device_buf
(
sizeof
(
CDataType
)
*
c_m_n_device_result
.
mDesc
.
GetElementSpace
());
DeviceMem
c0_m_n_device_buf
(
sizeof
(
CDataType
)
*
c0_m_n
.
mDesc
.
GetElementSpace
());
DeviceMem
c1_m_n_device_buf
(
sizeof
(
CDataType
)
*
c1_m_n
.
mDesc
.
GetElementSpace
());
a_m_k_device_buf
.
ToDevice
(
a_m_k
.
mData
.
data
());
b_k_n_device_buf
.
ToDevice
(
b_k_n
.
mData
.
data
());
c_m_n_device_buf
.
ToDevice
(
c_m_n_device_result
.
mData
.
data
());
c0_m_n_device_buf
.
ToDevice
(
c0_m_n
.
mData
.
data
());
c1_m_n_device_buf
.
ToDevice
(
c1_m_n
.
mData
.
data
());
// do GEMM
auto
gemm
=
typename
DeviceGemmInstance
<
ADataType
,
...
...
@@ -214,6 +234,8 @@ int main(int argc, char* argv[])
auto
argument
=
gemm
.
MakeArgument
(
static_cast
<
ADataType
*>
(
a_m_k_device_buf
.
GetDeviceBuffer
()),
static_cast
<
BDataType
*>
(
b_k_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c0_m_n_device_buf
.
GetDeviceBuffer
()),
static_cast
<
CDataType
*>
(
c1_m_n_device_buf
.
GetDeviceBuffer
()),
M
,
N
,
K
,
...
...
example/2_gemm_xdl_bias_add/include/device_gemm_bias_add.hpp
deleted
100644 → 0
View file @
d5679ea6
#ifndef DEVICE_GEMM_HPP
#define DEVICE_GEMM_HPP
#include <iostream>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceGemmPtr
=
std
::
unique_ptr
<
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
#endif
example/2_gemm_xdl_bias_add/include/device_gemm_xdl_bias_add.hpp
View file @
165e30cd
...
...
@@ -129,6 +129,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
using
AGridDesc_K0_M_K1
=
decltype
(
MakeAGridDescriptor_K0_M_K1
(
1
,
1
,
1
));
using
BGridDesc_K0_N_K1
=
decltype
(
MakeBGridDescriptor_K0_N_K1
(
1
,
1
,
1
));
using
CGridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
using
C0GridDesc_M_N
=
decltype
(
MakeCGridDescriptor_M_N
(
1
,
1
,
1
));
// hardcoding
// TODO: fix this
using
C1GridDesc_M_N
=
decltype
(
make_naive_tensor_descriptor
(
make_tuple
(
1
,
1
),
make_tuple
(
I1
,
I0
)));
// TODO remove these hacks
static
constexpr
auto
a_k0_m_k1_grid_step_hacks
=
...
...
@@ -179,6 +185,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
AGridDesc_K0_M_K1
,
BGridDesc_K0_N_K1
,
CGridDesc_M_N
,
C0GridDesc_M_N
,
C1GridDesc_M_N
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -221,6 +229,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
using
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
CGridDesc_M_N
{}));
using
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C0GridDesc_M_N
{}));
using
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
=
decltype
(
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
C1GridDesc_M_N
{}));
using
Block2CTileMap
=
decltype
(
GridwiseGemm
::
MakeBlock2CTileMap
(
CGridDesc_M_N
{},
1
,
1
));
// Argument
...
...
@@ -229,6 +243,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
Argument
(
const
ADataType
*
p_a_grid
,
const
BDataType
*
p_b_grid
,
CDataType
*
p_c_grid
,
const
CDataType
*
p_c0_grid
,
const
CDataType
*
p_c1_grid
,
index_t
M
,
index_t
N
,
index_t
K
,
...
...
@@ -243,10 +259,16 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c0_grid_
{
p_c0_grid
},
p_c1_grid_
{
p_c1_grid
},
a_grid_desc_k0_m_k1_
{},
b_grid_desc_k0_n_k1_
{},
c_grid_desc_m_n_
{},
c0_grid_desc_m_n_
{},
c1_grid_desc_m_n_
{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
{},
block_2_ctile_map_
{},
M01_
{
M01
},
N01_
{
N01
},
...
...
@@ -261,12 +283,27 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
c_grid_desc_m_n_
=
DeviceGemmXdl_two_extra_source_reduce
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
// assume C0 has same layout as C
// TODO: fix this
c0_grid_desc_m_n_
=
DeviceGemmXdl_two_extra_source_reduce
::
MakeCGridDescriptor_M_N
(
M
,
N
,
StrideC
);
// hardcoding C1 layout
// TODO: fix this
c1_grid_desc_m_n_
=
make_naive_tensor_descriptor
(
make_tuple
(
M
,
N
),
make_tuple
(
I1
,
I0
));
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_k0_m_k1_
,
b_grid_desc_k0_n_k1_
,
c_grid_desc_m_n_
,
M01_
,
N01_
))
{
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c_grid_desc_m_n_
);
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c0_grid_desc_m_n_
);
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
=
GridwiseGemm
::
MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2
(
c1_grid_desc_m_n_
);
block_2_ctile_map_
=
GridwiseGemm
::
MakeBlock2CTileMap
(
c_grid_desc_m_n_
,
M01
,
N01
);
}
}
...
...
@@ -275,10 +312,16 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
const
CDataType
*
p_c0_grid_
;
const
CDataType
*
p_c1_grid_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
C0GridDesc_M_N
c0_grid_desc_m_n_
;
C1GridDesc_M_N
c1_grid_desc_m_n_
;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
;
Block2CTileMap
block_2_ctile_map_
;
index_t
M01_
;
index_t
N01_
;
...
...
@@ -305,6 +348,12 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
std
::
cout
<<
"arg.c_grid_desc_m_n_{ "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c0_grid_desc_m_n_{ "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c0_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
std
::
cout
<<
"arg.c1_grid_desc_m_n_{ "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I0
)
<<
", "
<<
arg
.
c1_grid_desc_m_n_
.
GetLength
(
I1
)
<<
"}"
<<
std
::
endl
;
}
if
(
!
GridwiseGemm
::
CheckValidity
(
arg
.
a_grid_desc_k0_m_k1_
,
...
...
@@ -335,6 +384,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -349,9 +402,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -367,6 +424,10 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
BGridDesc_K0_N_K1
>
,
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
C0GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
remove_reference_t
<
DeviceGemmXdl_two_extra_source_reduce
::
C1GridDesc_M0_N0_M1_N1_M2_M3_M4_N2
>
,
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
,
...
...
@@ -381,9 +442,13 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c0_grid_
,
arg
.
p_c1_grid_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c0_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c1_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
a_element_op_
,
arg
.
b_element_op_
,
arg
.
c_element_op_
,
...
...
@@ -424,6 +489,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
const
BDataType
*
p_b
,
CDataType
*
p_c
,
const
CDataType
*
p_c0
,
const
CDataType
*
p_c1
,
index_t
M
,
index_t
N
,
index_t
K
,
...
...
@@ -437,6 +504,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
return
Argument
{
p_a
,
p_b
,
p_c
,
p_c0
,
p_c1
,
M
,
N
,
K
,
...
...
@@ -456,6 +525,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
const
void
*
p_c0
,
const
void
*
p_c1
,
index_t
M
,
index_t
N
,
index_t
K
,
...
...
@@ -469,6 +540,8 @@ struct DeviceGemmXdl_two_extra_source_reduce : public BaseOperator
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
CDataType
*>
(
p_c
),
static_cast
<
const
CDataType
*>
(
p_c0
),
static_cast
<
const
CDataType
*>
(
p_c1
),
M
,
N
,
K
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment