gaoqiong / composable_kernel — Commit 0b11569f

Authored Jul 01, 2022 by Chao Liu

    Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute

Parents: e8d3a0fb, fa9a0a5c

Changes: 554 files in total; showing 20 changed files with 389 additions and 195 deletions (+389 −195).
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp                       +3    −0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp                    +3    −0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp                      +3    −0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp                        +3    −0
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp         +3    −0
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp                +3    −0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp      +3    −0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp      +3    −0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp      +3    −0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp      +3    −0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp        +3    −0
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp    +3    −0
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp  +3    −0
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp          +3    −0
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp                     +59   −42
include/ck/tensor_operation/gpu/device/device_base.hpp                                 +3    −0
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp                         +45   −0
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp     +198  −125
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp                     +15   −12
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp                   +27   −16
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
 #define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
 ...

include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
 #define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
 ...

include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/tensor_description/cluster_descriptor.hpp"
 ...

include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include "ck/utility/common_header.hpp"
 ...

include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
 #define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
 ...

include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 namespace ck {
 ...

include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #ifndef CONVOLUTION_FORWARD_SPECIALIZATION
 #define CONVOLUTION_FORWARD_SPECIALIZATION
 ...
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp  (+59 −42)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include <iostream>
 ...
@@ -7,7 +10,7 @@
 #include "ck/utility/common_header.hpp"
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
-#include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
 #include "ck/device_utility/device_prop.hpp"
 #include "ck/device_utility/kernel_launch.hpp"
 ...
@@ -32,7 +35,7 @@ template <typename ADataType,
           index_t DScalarPerVector,
           index_t EScalarPerVector,
           index_t FScalarPerVector>
-struct Device5AryElementwise : public BaseOperator
+struct Device5AryElementwise : public DeviceElementwise<5, 1, NDim, ElementwiseFunctor>
 {
     static constexpr auto I0 = Number<0>{};
 ...
@@ -265,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
         return true;
     };

-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             const CDataType* p_c,
-                             const DDataType* p_d,
-                             const EDataType* p_e,
-                             FDataType* p_f,
+    static auto MakeArgument(std::array<const void*, 5> p_inputs,
+                             std::array<void*, 1> p_outputs,
                              std::vector<index_t> lengths,
                              std::vector<index_t> a_strides,
                              std::vector<index_t> b_strides,
 ...
@@ -280,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
                              std::vector<index_t> f_strides,
                              ElementwiseFunctor functor)
     {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_d,
-                        p_e,
-                        p_f,
+        return Argument{static_cast<const ADataType*>(p_inputs[0]),
+                        static_cast<const BDataType*>(p_inputs[1]),
+                        static_cast<const CDataType*>(p_inputs[2]),
+                        static_cast<const DDataType*>(p_inputs[3]),
+                        static_cast<const EDataType*>(p_inputs[4]),
+                        static_cast<FDataType*>(p_outputs[0]),
                         lengths,
                         a_strides,
                         b_strides,
 ...
@@ -296,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
                         functor};
     }

-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      const void* p_c,
-                                                      const void* p_d,
-                                                      const void* p_e,
-                                                      void* p_f,
-                                                      std::vector<index_t> lengths,
-                                                      std::vector<index_t> a_strides,
-                                                      std::vector<index_t> b_strides,
-                                                      std::vector<index_t> c_strides,
-                                                      std::vector<index_t> d_strides,
-                                                      std::vector<index_t> e_strides,
-                                                      std::vector<index_t> f_strides,
-                                                      ElementwiseFunctor functor)
+    std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, 5> p_inputs,
+                        std::array<void*, 1> p_outputs,
+                        std::vector<index_t> lengths,
+                        std::vector<std::vector<index_t>> input_strides,
+                        std::vector<std::vector<index_t>> output_strides,
+                        ElementwiseFunctor functor) override
     {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<const CDataType*>(p_c),
-                                          static_cast<const DDataType*>(p_d),
-                                          static_cast<const EDataType*>(p_e),
-                                          static_cast<FDataType*>(p_f),
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
+                                          static_cast<const BDataType*>(p_inputs[1]),
+                                          static_cast<const CDataType*>(p_inputs[2]),
+                                          static_cast<const DDataType*>(p_inputs[3]),
+                                          static_cast<const EDataType*>(p_inputs[4]),
+                                          static_cast<FDataType*>(p_outputs[0]),
                                           lengths,
-                                          a_strides,
-                                          b_strides,
-                                          c_strides,
-                                          d_strides,
-                                          e_strides,
-                                          f_strides,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          input_strides[2],
+                                          input_strides[3],
+                                          input_strides[4],
+                                          output_strides[0],
                                           functor);
     }

     static auto MakeInvoker() { return Invoker{}; }

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer()
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>();
     }
+
+    // polymorphic
+    std::string GetTypeString() const override
+    {
+        auto str = std::stringstream();
+        // clang-format off
+        str << "Device5aryElementwise"
+            << "<"
+            << "NDim = " << NDim
+            << "MPerThread = " << MPerThread
+            << "AScalarPerVector = " << AScalarPerVector
+            << "BScalarPerVector = " << BScalarPerVector
+            << "CScalarPerVector = " << CScalarPerVector
+            << "DScalarPerVector = " << DScalarPerVector
+            << "EScalarPerVector = " << EScalarPerVector
+            << "FScalarPerVector = " << FScalarPerVector
+            << ">";
+        // clang-format on
+        return str.str();
+    }
 };

 } // namespace device
 } // namespace tensor_operation
 ...
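
The net effect of this diff is that every concrete elementwise op now shares one type-erased
entry point in the DeviceElementwise base: inputs and outputs arrive as fixed-size arrays of
void pointers, and each op casts them back to its element types. The following standalone
sketch (the ElementwiseBase and AddF32 names are illustrative stand-ins, not CK's actual
classes) shows why this works for any arity:

#include <array>
#include <cstdio>
#include <vector>

// Hypothetical stand-in for the unified base: one virtual signature covers any
// arity, because inputs/outputs arrive as arrays of type-erased pointers.
template <int NumInput, int NumOutput>
struct ElementwiseBase
{
    virtual ~ElementwiseBase() = default;
    virtual void Run(std::array<const void*, NumInput> inputs,
                     std::array<void*, NumOutput> outputs,
                     std::vector<int> lengths) = 0;
};

// A binary "add" op recovers its concrete element types with static_cast,
// mirroring what Device5AryElementwise now does in MakeArgument/MakeArgumentPointer.
struct AddF32 : ElementwiseBase<2, 1>
{
    void Run(std::array<const void*, 2> inputs,
             std::array<void*, 1> outputs,
             std::vector<int> lengths) override
    {
        const float* a = static_cast<const float*>(inputs[0]);
        const float* b = static_cast<const float*>(inputs[1]);
        float* c       = static_cast<float*>(outputs[0]);
        for(int i = 0; i < lengths[0]; ++i)
            c[i] = a[i] + b[i];
    }
};

int main()
{
    float a[4] = {1, 2, 3, 4}, b[4] = {10, 20, 30, 40}, c[4] = {};
    AddF32 op;
    op.Run({a, b}, {c}, {4});
    std::printf("%g %g %g %g\n", c[0], c[1], c[2], c[3]); // 11 22 33 44
}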
include/ck/tensor_operation/gpu/device/device_base.hpp  (+3 −0)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include <string>
 ...
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp  (new file, 0 → 100644, +45 −0)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <iostream>
#include <vector>

#include "device_base.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
struct DeviceBatchedGemm : public BaseOperator
{
    virtual std::unique_ptr<BaseArgument>
    MakeArgumentPointer(const void* p_a,
                        const void* p_b,
                        void* p_c,
                        ck::index_t M,
                        ck::index_t N,
                        ck::index_t K,
                        ck::index_t StrideA,
                        ck::index_t StrideB,
                        ck::index_t StrideC,
                        AElementwiseOperation a_element_op,
                        BElementwiseOperation b_element_op,
                        CElementwiseOperation c_element_op,
                        ck::index_t Batch) = 0;

    virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
};

template <typename AElementwiseOperation,
          typename BElementwiseOperation,
          typename CElementwiseOperation>
using DeviceBatchedGemmPtr = std::unique_ptr<
    DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>>;

} // namespace device
} // namespace tensor_operation
} // namespace ck
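
A plausible client-side view of this new abstract interface follows. This is a hedged
sketch, not code from this commit: the run_all helper, the raw device pointers, and the
PassThrough element-ops are illustrative assumptions, the element-wise header path is
assumed from this commit's layout, and the exact BaseInvoker::Run signature should be
checked against device_base.hpp.

#include <memory>
#include <vector>

#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
#include "ck/tensor_operation/gpu/element/element_wise_operation.hpp" // path assumed

using PassThrough = ck::tensor_operation::element_wise::PassThrough;
using BatchedGemmPtr =
    ck::tensor_operation::device::DeviceBatchedGemmPtr<PassThrough, PassThrough, PassThrough>;

// Run every registered instance that supports the problem size; this is the
// profiler-style loop a type-erased op interface like this is designed for.
void run_all(std::vector<BatchedGemmPtr>& ops,
             const void* p_a, const void* p_b, void* p_c,
             ck::index_t M, ck::index_t N, ck::index_t K,
             ck::index_t StrideA, ck::index_t StrideB, ck::index_t StrideC,
             ck::index_t Batch)
{
    for(auto& op : ops)
    {
        auto arg = op->MakeArgumentPointer(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC,
                                           PassThrough{}, PassThrough{}, PassThrough{}, Batch);
        if(op->IsSupportedArgument(arg.get()))
            op->MakeInvokerPointer()->Run(arg.get());
    }
}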
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp  (+198 −125)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include <iostream>
 ...
@@ -20,16 +23,16 @@ namespace device {
 template <typename GridwiseGemm,
           typename FloatAB,
           typename FloatC,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
           typename AGridDesc_AK0_M_AK1,
           typename BGridDesc_BK0_N_BK1,
           typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-          typename DGridDescriptor_MBlock_MPerBlock,
+          typename ReduceGridDescriptor_MBlock_MPerBlock,
           typename ComputeBasePrtOfBatch,
           typename Block2CTileMap,
           bool HasMainK0BlockLoop>
@@ -41,18 +44,18 @@ __global__ void
         const FloatAB* __restrict__ p_a_grid,
         const FloatAB* __restrict__ p_b_grid,
         FloatC* __restrict__ p_c_grid,
-        DPtrsGlobal p_ds_grid,
+        ReducePtrsGlobal p_reduces_grid,
         const index_t batch_count,
         const AElementwiseOperation a_element_op,
         const BElementwiseOperation b_element_op,
         const CElementwiseOperation c_element_op,
-        const DxsInElementwiseOperation dxs_in_element_op,
-        const DxsReduceAccElementwiseOperation dxs_out_element_op,
+        const ReduceInElementwiseOperations reduce_in_element_ops,
+        const ReduceAccElementwiseOperations reduce_out_element_ops,
         const AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1,
         const BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1,
         const CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock,
-        const DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock,
+        const ReduceGridDescriptor_MBlock_MPerBlock reduce_grid_desc_mblock_mperblock,
         const ComputeBasePrtOfBatch compute_base_ptr_of_batch_,
         const Block2CTileMap block_2_ctile_map)
 {
 ...
@@ -68,10 +71,10 @@ __global__ void
     const long_index_t c_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_base_ptr_of_batch_.GetCBasePtr(g_idx)));

-    static_for<0, p_ds_grid.Size(), 1>{}([&](auto In) {
+    static_for<0, p_reduces_grid.Size(), 1>{}([&](auto In) {
         const long_index_t d_batch_offset = __builtin_amdgcn_readfirstlane(
             static_cast<long_index_t>(compute_base_ptr_of_batch_.GetDBasePtr(g_idx, In)));
-        p_ds_grid(In) = p_ds_grid(In) + d_batch_offset;
+        p_reduces_grid(In) = p_reduces_grid(In) + d_batch_offset;
     });

     __shared__ char p_shared[GridwiseGemm::GetSharedMemoryNumberOfByte()];
 ...
@@ -79,36 +82,36 @@ __global__ void
     GridwiseGemm::template Run<HasMainK0BlockLoop>(p_a_grid + a_batch_offset,
                                                    p_b_grid + b_batch_offset,
                                                    p_c_grid + c_batch_offset,
-                                                   p_ds_grid,
+                                                   p_reduces_grid,
                                                    p_shared,
                                                    a_element_op,
                                                    b_element_op,
                                                    c_element_op,
-                                                   dxs_in_element_op,
-                                                   dxs_out_element_op,
+                                                   reduce_in_element_ops,
+                                                   reduce_out_element_ops,
                                                    a_grid_desc_ak0_m_ak1,
                                                    b_grid_desc_bk0_n_bk1,
                                                    c_grid_desc_mblock_mperblock_nblock_nperblock,
-                                                   d_grid_desc_mblock_mperblock,
+                                                   reduce_grid_desc_mblock_mperblock,
                                                    block_2_ctile_map);
 #else
     ignore = p_a_grid;
     ignore = p_b_grid;
     ignore = p_c_grid;
-    ignore = p_ds_grid;
+    ignore = p_reduces_grid;
     ignore = batch_count;
     ignore = a_element_op;
     ignore = b_element_op;
     ignore = c_element_op;
-    ignore = dxs_in_element_op;
-    ignore = dxs_out_element_op;
+    ignore = reduce_in_element_ops;
+    ignore = reduce_out_element_ops;
     ignore = a_grid_desc_ak0_m_ak1;
     ignore = b_grid_desc_bk0_n_bk1;
     ignore = c_grid_desc_mblock_mperblock_nblock_nperblock;
-    ignore = d_grid_desc_mblock_mperblock;
+    ignore = reduce_grid_desc_mblock_mperblock;
     ignore = compute_base_ptr_of_batch_;
     ignore = block_2_ctile_map;
 #endif // end of if defined (defined(__gfx908__) || defined(__gfx90a__))
 }

 // Note: inter-wave loop scheduler is rolled out to c-shuffle version first. Becuase non c-shuffle
 ...
@@ -123,14 +126,14 @@ template <typename ALayout,
           typename GemmAccDataType,
           typename CShuffleDataType,
           typename ReduceAccDataType,
-          typename DPtrsGlobal,
+          typename ReducePtrsGlobal,
           typename AElementwiseOperation,
           typename BElementwiseOperation,
           typename CElementwiseOperation,
-          typename DxsReduceOperation,
-          typename DxsInElementwiseOperation,
-          typename DxsReduceAccElementwiseOperation,
-          typename DGlobalMemoryDataOperation,
+          typename ReduceOperations,
+          typename ReduceInElementwiseOperations,
+          typename ReduceAccElementwiseOperations,
+          typename ReduceGlobalMemoryDataOperation,
           GemmSpecialization GemmSpec,
           index_t NumGemmKPrefetchStage,
           index_t BlockSize,
 ...
@@ -165,12 +168,7 @@ template <typename ALayout,
           index_t CReduceThreadLds2VGprCopySrcDstScalarPerVector_NPerBlock,
           index_t CReduceThreadVgpr2GlobalCopySrcDstScalarPerVector_MPerBlock,
           LoopScheduler LoopSched = make_default_loop_scheduler()>
 struct DeviceBatchedGemmReduce_Xdl_CShuffle
-    : public DeviceGemmReduce<AElementwiseOperation,
-                              BElementwiseOperation,
-                              CElementwiseOperation,
-                              DxsInElementwiseOperation,
-                              DxsReduceAccElementwiseOperation>
+    : public DeviceGemmReduce<0, ReduceOperations::Size()>
 {
     using DeviceOp = DeviceBatchedGemmReduce_Xdl_CShuffle;
 ...
@@ -443,7 +441,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
     // assume D is packed tensor
-    static auto MakeDGridDescriptor_M(index_t MRaw)
+    static auto MakeReduceGridDescriptor_M(index_t MRaw)
     {
         const auto d_grid_desc_mraw = make_naive_tensor_descriptor_packed(make_tuple(MRaw));
 ...
@@ -471,7 +469,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
     using AGridDesc_AK0_M_AK1 = decltype(MakeAGridDescriptor_AK0_M_AK1(1, 1, 1));
     using BGridDesc_BK0_N_BK1 = decltype(MakeBGridDescriptor_BK0_N_BK1(1, 1, 1));
     using CGridDesc_M_N       = decltype(MakeCGridDescriptor_M_N(1, 1, 1));
-    using DGridDesc_M         = decltype(MakeDGridDescriptor_M(1));
+    using ReduceGridDesc_M    = decltype(MakeReduceGridDescriptor_M(1));

     struct ComputeBasePtrOfStridedBatch
     {
 ...
@@ -524,19 +522,19 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
         CShuffleDataType,
         CDataType,
         ReduceAccDataType,
-        DPtrsGlobal,
+        ReducePtrsGlobal,
         AElementwiseOperation,
         BElementwiseOperation,
         CElementwiseOperation,
-        DxsReduceOperation,
-        DxsInElementwiseOperation,
-        DxsReduceAccElementwiseOperation,
+        ReduceOperations,
+        ReduceInElementwiseOperations,
+        ReduceAccElementwiseOperations,
         InMemoryDataOperationEnum::Set,
-        DGlobalMemoryDataOperation,
+        ReduceGlobalMemoryDataOperation,
         AGridDesc_AK0_M_AK1,
         BGridDesc_BK0_N_BK1,
         CGridDesc_M_N,
-        DGridDesc_M,
+        ReduceGridDesc_M,
         NumGemmKPrefetchStage,
         BlockSize,
         MPerBlock,
 ...
@@ -579,7 +577,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
         Argument(const ADataType* p_a_grid,
                  const BDataType* p_b_grid,
                  CDataType* p_c_grid,
-                 DPtrsGlobal p_ds_grid,
+                 ReducePtrsGlobal p_reduces_grid,
                  index_t MRaw,
                  index_t NRaw,
                  index_t KRaw,
 ...
@@ -589,31 +587,31 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 DxsInElementwiseOperation dxs_in_element_op,
-                 DxsReduceAccElementwiseOperation dxs_out_element_op,
-                 index_t BatchCount)
+                 ReduceInElementwiseOperations reduce_in_element_ops,
+                 ReduceAccElementwiseOperations reduce_out_element_ops,
+                 index_t Batch)
             : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-             p_ds_grid_{p_ds_grid},
-             BatchCount_(BatchCount),
+             p_reduces_grid_{p_reduces_grid},
+             Batch_(Batch),
              a_grid_desc_ak0_m_ak1_{DeviceOp::MakeAGridDescriptor_AK0_M_AK1(MRaw, KRaw, StrideA)},
              b_grid_desc_bk0_n_bk1_{DeviceOp::MakeBGridDescriptor_BK0_N_BK1(KRaw, NRaw, StrideB)},
              c_grid_desc_m_n_{DeviceOp::MakeCGridDescriptor_M_N(MRaw, NRaw, StrideC)},
-             d_grid_desc_m_{DeviceOp::MakeDGridDescriptor_M(MRaw)},
+             reduce_grid_desc_m_{DeviceOp::MakeReduceGridDescriptor_M(MRaw)},
              c_grid_desc_mblock_mperblock_nblock_nperblock_{},
-             d_grid_desc_mblock_mperblock_{},
+             reduce_grid_desc_mblock_mperblock_{},
              compute_base_ptr_of_batch_{
                  type_convert<index_t>(a_grid_desc_ak0_m_ak1_.GetElementSpaceSize()),
                  type_convert<index_t>(b_grid_desc_bk0_n_bk1_.GetElementSpaceSize()),
                  type_convert<index_t>(c_grid_desc_m_n_.GetElementSpaceSize()),
-                 type_convert<index_t>(d_grid_desc_m_.GetElementSpaceSize())},
+                 type_convert<index_t>(reduce_grid_desc_m_.GetElementSpaceSize())},
              block_2_ctile_map_{GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_)},
              a_element_op_{a_element_op},
              b_element_op_{b_element_op},
              c_element_op_{c_element_op},
-             dxs_in_element_op_{dxs_in_element_op},
-             dxs_out_element_op_{dxs_out_element_op}
+             reduce_in_element_ops_{reduce_in_element_ops},
+             reduce_out_element_ops_{reduce_out_element_ops}
         {
             if(GridwiseGemm::CheckValidity(a_grid_desc_ak0_m_ak1_,
                                            b_grid_desc_bk0_n_bk1_,
 ...
@@ -624,8 +622,8 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                     GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(
                         c_grid_desc_m_n_);
-                d_grid_desc_mblock_mperblock_ =
-                    GridwiseGemm::MakeDGridDescriptor_MBlock_MPerBlock(d_grid_desc_m_);
+                reduce_grid_desc_mblock_mperblock_ =
+                    GridwiseGemm::MakeReduceGridDescriptor_MBlock_MPerBlock(reduce_grid_desc_m_);
             }
         }
 ...
@@ -633,22 +631,23 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        DPtrsGlobal p_ds_grid_;
-        index_t BatchCount_;
+        ReducePtrsGlobal p_reduces_grid_;
+        index_t Batch_;
         AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
         BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
         CGridDesc_M_N c_grid_desc_m_n_;
-        DGridDesc_M d_grid_desc_m_;
+        ReduceGridDesc_M reduce_grid_desc_m_;
         typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
             c_grid_desc_mblock_mperblock_nblock_nperblock_;
-        typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock d_grid_desc_mblock_mperblock_;
+        typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock
+            reduce_grid_desc_mblock_mperblock_;
         ComputeBasePtrOfStridedBatch compute_base_ptr_of_batch_;
         typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
         AElementwiseOperation a_element_op_;
         BElementwiseOperation b_element_op_;
         CElementwiseOperation c_element_op_;
-        DxsInElementwiseOperation dxs_in_element_op_;
-        DxsReduceAccElementwiseOperation dxs_out_element_op_;
+        ReduceInElementwiseOperations reduce_in_element_ops_;
+        ReduceAccElementwiseOperations reduce_out_element_ops_;
     };

     // Invoker
 ...
@@ -660,7 +659,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
 #if 0
         {
-            std::cout << "arg.BatchCount_ = " << arg.BatchCount_ << std::endl;
+            std::cout << "arg.Batch_ = " << arg.Batch_ << std::endl;
             std::cout << "arg.a_grid_desc_ak0_m_ak1_{"
                       << arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) << ", "
 ...
@@ -675,7 +674,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
             std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
                       << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
-            std::cout << "arg.d_grid_desc_m_{ " << arg.d_grid_desc_m_.GetLength(I0) << "}"
+            std::cout << "arg.reduce_grid_desc_m_{ " << arg.reduce_grid_desc_m_.GetLength(I0) << "}"
                       << std::endl;
         }
 #endif
 ...
@@ -689,7 +688,7 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
             const auto K =
                 arg.a_grid_desc_ak0_m_ak1_.GetLength(I0) * arg.a_grid_desc_ak0_m_ak1_.GetLength(I2);
 ...
@@ -701,16 +700,16 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                     GridwiseGemm,
                     ADataType, // TODO: distiguish A/B datatype
                     CDataType,
-                    DPtrsGlobal,
+                    ReducePtrsGlobal,
                     AElementwiseOperation,
                     BElementwiseOperation,
                     CElementwiseOperation,
-                    DxsInElementwiseOperation,
-                    DxsReduceAccElementwiseOperation,
+                    ReduceInElementwiseOperations,
+                    ReduceAccElementwiseOperations,
                     DeviceOp::AGridDesc_AK0_M_AK1,
                     DeviceOp::BGridDesc_BK0_N_BK1,
                     typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
-                    typename GridwiseGemm::DGridDescriptor_MBlock_MPerBlock,
+                    typename GridwiseGemm::ReduceGridDescriptor_MBlock_MPerBlock,
                     ComputeBasePtrOfStridedBatch,
                     typename GridwiseGemm::DefaultBlock2CTileMap,
                     true>;
 ...
@@ -724,17 +723,17 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.p_ds_grid_,
-                    arg.BatchCount_,
+                    arg.p_reduces_grid_,
+                    arg.Batch_,
                     arg.a_element_op_,
                     arg.b_element_op_,
                     arg.c_element_op_,
-                    arg.dxs_in_element_op_,
-                    arg.dxs_out_element_op_,
+                    arg.reduce_in_element_ops_,
+                    arg.reduce_out_element_ops_,
                     arg.a_grid_desc_ak0_m_ak1_,
                     arg.b_grid_desc_bk0_n_bk1_,
                     arg.c_grid_desc_mblock_mperblock_nblock_nperblock_,
-                    arg.d_grid_desc_mblock_mperblock_,
+                    arg.reduce_grid_desc_mblock_mperblock_,
                     arg.compute_base_ptr_of_batch_,
                     arg.block_2_ctile_map_);

 (Hunks @@ -744,16 +743,16 @@ and @@ -767,17 +766,17 @@ apply the same renames to the
 HasMainK0BlockLoop == false kernel instantiation and launch.)
 ...
@@ -821,39 +820,77 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
-    static auto MakeArgument(const ADataType* p_a,
-                             const BDataType* p_b,
-                             CDataType* p_c,
-                             DPtrsGlobal p_dxs,
-                             index_t MRaw,
-                             index_t NRaw,
-                             index_t KRaw,
-                             index_t StrideA,
-                             index_t StrideB,
-                             index_t StrideC,
-                             AElementwiseOperation a_element_op,
-                             BElementwiseOperation b_element_op,
-                             CElementwiseOperation c_element_op,
-                             DxsInElementwiseOperation dxs_in_element_op,
-                             DxsReduceAccElementwiseOperation dxs_out_element_op,
-                             index_t BatchCount)
+    static constexpr int NumReduce = ReduceOperations::Size();
+
+    static auto MakeArgument(const void* p_a,
+                             const void* p_b,
+                             const void* p_bias,
+                             std::array<const void*, 0> p_ds,
+                             void* p_c,
+                             std::array<void*, NumReduce> p_reduces,
+                             ck::index_t M,
+                             ck::index_t N,
+                             ck::index_t K,
+                             ck::index_t StrideA,
+                             ck::index_t StrideB,
+                             ck::index_t StrideC,
+                             std::array<ck::index_t, 0> StrideDs,
+                             std::array<void*, 3> gemm_element_ops,
+                             std::array<void*, 0> d_element_ops,
+                             std::array<void*, NumReduce> reduce_in_element_op,
+                             std::array<void*, NumReduce> reduce_out_element_op,
+                             index_t Batch)
     {
-        return Argument{p_a,
-                        p_b,
-                        p_c,
-                        p_dxs,
-                        MRaw,
-                        NRaw,
-                        KRaw,
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+
+        ReducePtrsGlobal reduce_tuple = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReducePtrsGlobal{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return static_cast<T*>(p_reduces[I]);
+            },
+            Number<NumReduce>{});
+
+        ReduceInElementwiseOperations reduce_in_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceInElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_in_element_op[I]));
+            },
+            Number<NumReduce>{});
+        ReduceAccElementwiseOperations reduce_out_element_ops = generate_tuple(
+            [&](auto I) {
+                auto tmp = ReduceAccElementwiseOperations{}[I];
+                using T  = remove_pointer_t<decltype(tmp)>;
+                return *(static_cast<T*>(reduce_out_element_op[I]));
+            },
+            Number<NumReduce>{});
+
+        AElementwiseOperation a_element_op =
+            *(static_cast<AElementwiseOperation*>(gemm_element_ops[0]));
+        BElementwiseOperation b_element_op =
+            *(static_cast<BElementwiseOperation*>(gemm_element_ops[1]));
+        CElementwiseOperation c_element_op =
+            *(static_cast<CElementwiseOperation*>(gemm_element_ops[2]));
+
+        return Argument{static_cast<const ADataType*>(p_a),
+                        static_cast<const BDataType*>(p_b),
+                        static_cast<CDataType*>(p_c),
+                        reduce_tuple,
+                        M,
+                        N,
+                        K,
                         StrideA,
                         StrideB,
                         StrideC,
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        dxs_in_element_op,
-                        dxs_out_element_op,
-                        BatchCount};
+                        reduce_in_element_ops,
+                        reduce_out_element_ops,
+                        Batch};
     }

     static auto MakeInvoker() { return Invoker{}; }
 ...
@@ -862,38 +899,74 @@ struct DeviceBatchedGemmReduce_Xdl_CShuffle
     std::unique_ptr<BaseArgument>
-    MakeArgumentPointer(const void* p_a,
-                        const void* p_b,
-                        void* p_c,
-                        void* p_dxs,
-                        index_t MRaw,
-                        index_t NRaw,
-                        index_t KRaw,
-                        index_t StrideA,
-                        index_t StrideB,
-                        index_t StrideC,
-                        AElementwiseOperation a_element_op,
-                        BElementwiseOperation b_element_op,
-                        CElementwiseOperation c_element_op,
-                        DxsInElementwiseOperation dxs_in_element_op,
-                        DxsReduceAccElementwiseOperation dxs_out_element_op,
-                        index_t BatchCount) override
+    MakeArgumentPointer(const void* p_a,
+                        const void* p_b,
+                        const void* p_bias,
+                        std::array<const void*, 0> p_ds,
+                        void* p_c,
+                        std::array<void*, NumReduce> p_reduces,
+                        ck::index_t M,
+                        ck::index_t N,
+                        ck::index_t K,
+                        ck::index_t StrideA,
+                        ck::index_t StrideB,
+                        ck::index_t StrideC,
+                        std::array<ck::index_t, 0> StrideDs,
+                        std::array<void*, 3> gemm_element_ops,
+                        std::array<void*, 0> d_element_ops,
+                        std::array<void*, NumReduce> reduce_in_element_op,
+                        std::array<void*, NumReduce> reduce_out_element_op,
+                        index_t Batch = 1) override
     {
-        DPtrsGlobal dxs_tuple = *(static_cast<DPtrsGlobal*>(p_dxs));
+        (void)p_bias;
+        (void)p_ds;
+        (void)StrideDs;
+        (void)d_element_ops;
+        ... (reduce_tuple, reduce_in_element_ops, reduce_out_element_ops, and the
+             a/b/c element-ops are rebuilt from the void* arrays with the same
+             generate_tuple / static_cast pattern as in MakeArgument above) ...
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
                                           static_cast<CDataType*>(p_c),
-                                          dxs_tuple,
-                                          MRaw,
-                                          NRaw,
-                                          KRaw,
+                                          reduce_tuple,
+                                          M,
+                                          N,
+                                          K,
                                           StrideA,
                                           StrideB,
                                           StrideC,
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          dxs_in_element_op,
-                                          dxs_out_element_op,
-                                          BatchCount);
+                                          reduce_in_element_ops,
+                                          reduce_out_element_ops,
+                                          Batch);
     }

     // polymorphic
 ...
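
The generate_tuple calls above rebuild CK's typed pointer tuple from the void* arrays that
the virtual interface hands over. A standard-library analogue of that rebuild step, using
std::index_sequence in place of CK's generate_tuple (all names here are illustrative, not
CK's), looks like this:

#include <array>
#include <cstddef>
#include <cstdio>
#include <tuple>
#include <utility>

// Rebuild a tuple of typed pointers (e.g. CK's ReducePtrsGlobal) from an array
// of type-erased void* handles; each element's type drives the static_cast.
template <typename PtrTuple, std::size_t... Is>
PtrTuple rebuild_impl(const std::array<void*, sizeof...(Is)>& raw, std::index_sequence<Is...>)
{
    return PtrTuple{static_cast<std::tuple_element_t<Is, PtrTuple>>(raw[Is])...};
}

template <typename PtrTuple, std::size_t N = std::tuple_size<PtrTuple>::value>
PtrTuple rebuild(const std::array<void*, N>& raw)
{
    return rebuild_impl<PtrTuple>(raw, std::make_index_sequence<N>{});
}

int main()
{
    float mean = 0.f;
    double var = 0.0;
    using ReducePtrs = std::tuple<float*, double*>; // e.g. mean + variance outputs

    std::array<void*, 2> p_reduces{&mean, &var}; // what the void* API hands over
    ReducePtrs typed = rebuild<ReducePtrs>(p_reduces);

    *std::get<0>(typed) = 1.5f;
    *std::get<1>(typed) = 2.5;
    std::printf("%g %g\n", mean, var); // 1.5 2.5
}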
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp  (+15 −12)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include <iostream>
 ...
@@ -7,7 +10,7 @@
 #include "ck/tensor_description/tensor_descriptor.hpp"
 #include "ck/tensor_description/tensor_descriptor_helper.hpp"
 #include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
-#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
+#include "ck/tensor_operation/gpu/device/device_batched_gemm.hpp"
 #include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
 #include "ck/device_utility/device_prop.hpp"
 ...
@@ -149,7 +152,7 @@ template <typename ADataType,
           ck::index_t CThreadTransferSrcDstVectorDim,
           ck::index_t CThreadTransferDstScalarPerVector>
 struct DeviceBatchedGemmXdl
-    : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
+    : public DeviceBatchedGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
 {
     static constexpr auto I0 = Number<0>{};
     static constexpr auto I1 = Number<1>{};
 ...
@@ -336,11 +339,11 @@ struct DeviceBatchedGemmXdl
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 index_t BatchCount)
+                 index_t Batch)
             : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
-             BatchCount_(BatchCount),
+             Batch_(Batch),
              a_grid_desc_k0_m_k1_{
                  DeviceBatchedGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA)},
              b_grid_desc_k0_n_k1_{
 ...
@@ -373,7 +376,7 @@ struct DeviceBatchedGemmXdl
         const ADataType* p_a_grid_;
         const BDataType* p_b_grid_;
         CDataType* p_c_grid_;
-        index_t BatchCount_;
+        index_t Batch_;
         AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
         BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
         CGridDesc_M_N c_grid_desc_m_n_;
 ...
@@ -417,7 +420,7 @@ struct DeviceBatchedGemmXdl
             const index_t grid_size =
-                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.BatchCount_;
+                arg.block_2_ctile_map_.CalculateGridSize(arg.c_grid_desc_m_n_) * arg.Batch_;
             const auto K =
                 arg.a_grid_desc_k0_m_k1_.GetLength(I0) * arg.a_grid_desc_k0_m_k1_.GetLength(I2);
 ...
@@ -448,7 +451,7 @@ and @@ -482,7 +485,7 @@ (both kernel-launch sites)
                     arg.p_a_grid_,
                     arg.p_b_grid_,
                     arg.p_c_grid_,
-                    arg.BatchCount_,
+                    arg.Batch_,
                     arg.a_grid_desc_k0_m_k1_,
                     arg.b_grid_desc_k0_n_k1_,
                     arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_,
 ...
@@ -536,7 +539,7 @@ and @@ -552,7 +555,7 @@ (MakeArgument)
                  AElementwiseOperation a_element_op,
                  BElementwiseOperation b_element_op,
                  CElementwiseOperation c_element_op,
-                 index_t BatchCount)
+                 index_t Batch)
     {
         return Argument{p_a,
                         p_b,
 ...
                         a_element_op,
                         b_element_op,
                         c_element_op,
-                        BatchCount};
+                        Batch};
     }

     static auto MakeInvoker() { return Invoker{}; }
 ...
@@ -570,7 +573,7 @@ and @@ -586,7 +589,7 @@ (MakeArgumentPointer)
                         AElementwiseOperation a_element_op,
                         BElementwiseOperation b_element_op,
                         CElementwiseOperation c_element_op,
-                        index_t BatchCount) override
+                        index_t Batch) override
     {
         return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                           static_cast<const BDataType*>(p_b),
 ...
                                           a_element_op,
                                           b_element_op,
                                           c_element_op,
-                                          BatchCount);
+                                          Batch);
     }

     // polymorphic
 ...
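
For context on the renamed Batch_ field: the grid is sized as (blocks per GEMM) × Batch,
and each block peels its batch index off its block id, which is how the kernel computes
g_idx before applying the per-batch base-pointer offsets. A minimal standalone sketch of
that mapping (host-side for clarity; the tile and batch counts are illustrative):

#include <cstdio>

int main()
{
    const int tiles_per_gemm = 6; // blocks needed by one GEMM's C tiles
    const int batch          = 4;
    const int grid_size      = tiles_per_gemm * batch; // mirrors CalculateGridSize(...) * arg.Batch_

    for(int block_id = 0; block_id < grid_size; ++block_id)
    {
        const int g_idx    = block_id / tiles_per_gemm; // which GEMM in the batch
        const int tile_idx = block_id % tiles_per_gemm; // which C tile inside it
        if(tile_idx == 0)
            std::printf("batch %d starts at block %d\n", g_idx, block_id);
    }
}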
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp  (+27 −16)

+// SPDX-License-Identifier: MIT
+// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
+
 #pragma once
 #include <iostream>
 ...
@@ -6,6 +9,7 @@
 #include "ck/device_utility/device_prop.hpp"
 #include "ck/device_utility/kernel_launch.hpp"
 #include "ck/tensor_operation/gpu/device/device_base.hpp"
+#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
 #include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"

 namespace ck {
 ...
@@ -22,7 +26,7 @@ template <typename ADataType,
           index_t AScalarPerVector,
           index_t BScalarPerVector,
           index_t CScalarPerVector>
-struct DeviceBinaryElementwise : public BaseOperator
+struct DeviceBinaryElementwise : public DeviceElementwise<2, 1, NDim, ElementwiseFunctor>
 {
     static constexpr auto I0 = Number<0>{};
 ...
@@ -195,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
         return true;
     };

-    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
-                                                      const void* p_b,
-                                                      void* p_c,
-                                                      std::vector<index_t> lengths,
-                                                      std::vector<index_t> a_strides,
-                                                      std::vector<index_t> b_strides,
-                                                      std::vector<index_t> c_strides,
-                                                      ElementwiseFunctor functor)
+    virtual std::unique_ptr<BaseArgument>
+    MakeArgumentPointer(std::array<const void*, 2> p_inputs,
+                        std::array<void*, 1> p_outputs,
+                        std::vector<index_t> lengths,
+                        std::vector<std::vector<index_t>> input_strides,
+                        std::vector<std::vector<index_t>> output_strides,
+                        ElementwiseFunctor functor) override
     {
-        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
-                                          static_cast<const BDataType*>(p_b),
-                                          static_cast<CDataType*>(p_c),
+        return std::make_unique<Argument>(static_cast<const ADataType*>(p_inputs[0]),
+                                          static_cast<const BDataType*>(p_inputs[1]),
+                                          static_cast<CDataType*>(p_outputs[0]),
                                           lengths,
-                                          a_strides,
-                                          b_strides,
-                                          c_strides,
+                                          input_strides[0],
+                                          input_strides[1],
+                                          output_strides[0],
                                           functor);
     }

-    std::unique_ptr<BaseInvoker> MakeInvokerPointer()
+    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
     {
         return std::make_unique<Invoker>();
     }

     // polymorphic
     std::string GetTypeString() const override
     {
         auto str = std::stringstream();
 ...
@@ -223,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
         // clang-format off
         str << "DeviceBinaryElementwise"
             << "<"
-            << "MPerThread = " << MPerThread
+            << "NDim = " << NDim
+            << "MPerThread = " << MPerThread
+            << "AScalarPerVector = " << AScalarPerVector
+            << "BScalarPerVector = " << BScalarPerVector
+            << "CScalarPerVector = " << CScalarPerVector
             << ">";
         // clang-format on
 ...
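
The reworked signature groups one stride vector per tensor into input_strides and
output_strides, and the op picks input_strides[0], input_strides[1], and output_strides[0]
internally. A small standalone sketch of preparing those arguments for a packed row-major
case (the packed_strides helper is illustrative, not part of CK):

#include <cstdio>
#include <vector>

// Packed row-major strides for the given lengths: innermost dimension has
// stride 1, each outer stride is the product of the inner extents.
std::vector<int> packed_strides(const std::vector<int>& lengths)
{
    std::vector<int> strides(lengths.size(), 1);
    for(int i = static_cast<int>(lengths.size()) - 2; i >= 0; --i)
        strides[i] = strides[i + 1] * lengths[i + 1];
    return strides;
}

int main()
{
    std::vector<int> lengths{8, 16}; // NDim == 2

    // One strides vector per tensor: two inputs (a, b), one output (c),
    // matching the shape of the new MakeArgumentPointer parameters.
    std::vector<std::vector<int>> input_strides{packed_strides(lengths),
                                                packed_strides(lengths)};
    std::vector<std::vector<int>> output_strides{packed_strides(lengths)};

    std::printf("a strides: %d %d\n", input_strides[0][0], input_strides[0][1]);  // 16 1
    std::printf("c strides: %d %d\n", output_strides[0][0], output_strides[0][1]); // 16 1
}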