Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
0b11569f
Commit
0b11569f
authored
Jul 01, 2022
by
Chao Liu
Browse files
Merge remote-tracking branch 'origin/develop' into batched_gemm_c_permute
parents
e8d3a0fb
fa9a0a5c
Changes
554
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
389 additions
and
195 deletions
+389
-195
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
.../ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
+3
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp
.../tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp
+3
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
...ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
+3
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+3
-0
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
...ration/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
+3
-0
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
...sor_operation/gpu/block/reduction_functions_blockwise.hpp
+3
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
+3
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
+3
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
+3
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
...ion/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
+3
-0
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
...ation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
+3
-0
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
...n/gpu/device/convolution_backward_data_specialization.hpp
+3
-0
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp
...gpu/device/convolution_backward_weight_specialization.hpp
+3
-0
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
...eration/gpu/device/convolution_forward_specialization.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
...k/tensor_operation/gpu/device/device_5ary_elementwise.hpp
+59
-42
include/ck/tensor_operation/gpu/device/device_base.hpp
include/ck/tensor_operation/gpu/device/device_base.hpp
+3
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
...de/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
+45
-0
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
...on/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
+198
-125
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
...k/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
+15
-12
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
...tensor_operation/gpu/device/device_binary_elementwise.hpp
+27
-16
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_dl_v2r3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v2r2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V2R2_HPP
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_dlops_v3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#ifndef CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
#define CK_BLOCKWISE_GEMM_DLOPS_V3_HPP
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/blockwise_tensor_slice_transfer_v5r1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/reduction_functions_blockwise.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/tensor_description/cluster_descriptor.hpp"
#include "ck/tensor_description/cluster_descriptor.hpp"
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v4r1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r1.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r2.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v6r3.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_v7.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
...
...
include/ck/tensor_operation/gpu/device/convolution_backward_data_specialization.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#ifndef CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
#define CONVOLUTION_BACKWARD_DATA_SPECIALIZATION
...
...
include/ck/tensor_operation/gpu/device/convolution_backward_weight_specialization.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
namespace
ck
{
namespace
ck
{
...
...
include/ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#ifndef CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
#define CONVOLUTION_FORWARD_SPECIALIZATION
...
...
include/ck/tensor_operation/gpu/device/device_5ary_elementwise.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <iostream>
...
@@ -7,7 +10,7 @@
...
@@ -7,7 +10,7 @@
#include "ck/utility/common_header.hpp"
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/device_
ba
se.hpp"
#include "ck/tensor_operation/gpu/device/device_
elementwi
se.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_5ary_Elementwise_1d.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/device_utility/kernel_launch.hpp"
...
@@ -32,7 +35,7 @@ template <typename ADataType,
...
@@ -32,7 +35,7 @@ template <typename ADataType,
index_t
DScalarPerVector
,
index_t
DScalarPerVector
,
index_t
EScalarPerVector
,
index_t
EScalarPerVector
,
index_t
FScalarPerVector
>
index_t
FScalarPerVector
>
struct
Device5AryElementwise
:
public
BaseOpera
tor
struct
Device5AryElementwise
:
public
DeviceElementwise
<
5
,
1
,
NDim
,
ElementwiseFunc
tor
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -265,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
...
@@ -265,12 +268,8 @@ struct Device5AryElementwise : public BaseOperator
return
true
;
return
true
;
};
};
static
auto
MakeArgument
(
const
ADataType
*
p_a
,
static
auto
MakeArgument
(
std
::
array
<
const
void
*
,
5
>
p_inputs
,
const
BDataType
*
p_b
,
std
::
array
<
void
*
,
1
>
p_outputs
,
const
CDataType
*
p_c
,
const
DDataType
*
p_d
,
const
EDataType
*
p_e
,
FDataType
*
p_f
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_strides
,
std
::
vector
<
index_t
>
b_strides
,
...
@@ -280,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
...
@@ -280,12 +279,12 @@ struct Device5AryElementwise : public BaseOperator
std
::
vector
<
index_t
>
f_strides
,
std
::
vector
<
index_t
>
f_strides
,
ElementwiseFunctor
functor
)
ElementwiseFunctor
functor
)
{
{
return
Argument
{
p_a
,
return
Argument
{
static_cast
<
const
ADataType
*>
(
p_inputs
[
0
])
,
p_b
,
static_cast
<
const
BDataType
*>
(
p_inputs
[
1
])
,
p_c
,
static_cast
<
const
CDataType
*>
(
p_inputs
[
2
])
,
p_d
,
static_cast
<
const
DDataType
*>
(
p_inputs
[
3
])
,
p_e
,
static_cast
<
const
EDataType
*>
(
p_inputs
[
4
])
,
p_f
,
static_cast
<
FDataType
*>
(
p_outputs
[
0
])
,
lengths
,
lengths
,
a_strides
,
a_strides
,
b_strides
,
b_strides
,
...
@@ -296,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
...
@@ -296,40 +295,58 @@ struct Device5AryElementwise : public BaseOperator
functor
};
functor
};
}
}
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
std
::
unique_ptr
<
BaseArgument
>
const
void
*
p_b
,
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
5
>
p_inputs
,
const
void
*
p_c
,
std
::
array
<
void
*
,
1
>
p_outputs
,
const
void
*
p_d
,
std
::
vector
<
index_t
>
lengths
,
const
void
*
p_e
,
std
::
vector
<
std
::
vector
<
index_t
>>
input_strides
,
void
*
p_f
,
std
::
vector
<
std
::
vector
<
index_t
>>
output_strides
,
std
::
vector
<
index_t
>
lengths
,
ElementwiseFunctor
functor
)
override
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
index_t
>
b_strides
,
std
::
vector
<
index_t
>
c_strides
,
std
::
vector
<
index_t
>
d_strides
,
std
::
vector
<
index_t
>
e_strides
,
std
::
vector
<
index_t
>
f_strides
,
ElementwiseFunctor
functor
)
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
inputs
[
0
]
),
static_cast
<
const
BDataType
*>
(
p_
b
),
static_cast
<
const
BDataType
*>
(
p_
inputs
[
1
]
),
static_cast
<
const
CDataType
*>
(
p_
c
),
static_cast
<
const
CDataType
*>
(
p_
inputs
[
2
]
),
static_cast
<
const
DDataType
*>
(
p_
d
),
static_cast
<
const
DDataType
*>
(
p_
inputs
[
3
]
),
static_cast
<
const
EDataType
*>
(
p_
e
),
static_cast
<
const
EDataType
*>
(
p_
inputs
[
4
]
),
static_cast
<
FDataType
*>
(
p_
f
),
static_cast
<
FDataType
*>
(
p_
outputs
[
0
]
),
lengths
,
lengths
,
a
_strides
,
input
_strides
[
0
]
,
b
_strides
,
input
_strides
[
1
]
,
c
_strides
,
input
_strides
[
2
]
,
d
_strides
,
input
_strides
[
3
]
,
e
_strides
,
input
_strides
[
4
]
,
f
_strides
,
output
_strides
[
0
]
,
functor
);
functor
);
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
};
{
return
std
::
make_unique
<
Invoker
>
();
}
// polymorphic
std
::
string
GetTypeString
()
const
override
{
auto
str
=
std
::
stringstream
();
// clang-format off
str
<<
"Device5aryElementwise"
<<
"<"
<<
"NDim = "
<<
NDim
<<
"MPerThread = "
<<
MPerThread
<<
"AScalarPerVector = "
<<
AScalarPerVector
<<
"BScalarPerVector = "
<<
BScalarPerVector
<<
"CScalarPerVector = "
<<
CScalarPerVector
<<
"DScalarPerVector = "
<<
DScalarPerVector
<<
"EScalarPerVector = "
<<
EScalarPerVector
<<
"FScalarPerVector = "
<<
FScalarPerVector
<<
">"
;
// clang-format on
return
str
.
str
();
}
};
// namespace device
}
// namespace device
}
// namespace device
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/tensor_operation/gpu/device/device_base.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <string>
#include <string>
...
...
include/ck/tensor_operation/gpu/device/device_batched_gemm.hpp
0 → 100644
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <vector>
#include "device_base.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
struct
DeviceBatchedGemm
:
public
BaseOperator
{
virtual
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
const
void
*
p_b
,
void
*
p_c
,
ck
::
index_t
M
,
ck
::
index_t
N
,
ck
::
index_t
K
,
ck
::
index_t
StrideA
,
ck
::
index_t
StrideB
,
ck
::
index_t
StrideC
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
ck
::
index_t
Batch
)
=
0
;
virtual
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
=
0
;
};
template
<
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CElementwiseOperation
>
using
DeviceBatchedGemmPtr
=
std
::
unique_ptr
<
DeviceBatchedGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>>
;
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/device_batched_gemm_reduce_xdl_cshuffle.hpp
View file @
0b11569f
This diff is collapsed.
Click to expand it.
include/ck/tensor_operation/gpu/device/device_batched_gemm_xdl.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <iostream>
...
@@ -7,7 +10,7 @@
...
@@ -7,7 +10,7 @@
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/device_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_
batched_
gemm.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_xdlops_v2r3.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/device_prop.hpp"
...
@@ -149,7 +152,7 @@ template <typename ADataType,
...
@@ -149,7 +152,7 @@ template <typename ADataType,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferSrcDstVectorDim
,
ck
::
index_t
CThreadTransferDstScalarPerVector
>
ck
::
index_t
CThreadTransferDstScalarPerVector
>
struct
DeviceBatchedGemmXdl
struct
DeviceBatchedGemmXdl
:
public
DeviceGemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
:
public
Device
Batched
Gemm
<
AElementwiseOperation
,
BElementwiseOperation
,
CElementwiseOperation
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
...
@@ -336,11 +339,11 @@ struct DeviceBatchedGemmXdl
...
@@ -336,11 +339,11 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
Count
)
index_t
Batch
)
:
p_a_grid_
{
p_a_grid
},
:
p_a_grid_
{
p_a_grid
},
p_b_grid_
{
p_b_grid
},
p_b_grid_
{
p_b_grid
},
p_c_grid_
{
p_c_grid
},
p_c_grid_
{
p_c_grid
},
Batch
Count
_
(
Batch
Count
),
Batch_
(
Batch
),
a_grid_desc_k0_m_k1_
{
a_grid_desc_k0_m_k1_
{
DeviceBatchedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
)},
DeviceBatchedGemmXdl
::
MakeAGridDescriptor_K0_M_K1
(
M
,
K
,
StrideA
)},
b_grid_desc_k0_n_k1_
{
b_grid_desc_k0_n_k1_
{
...
@@ -373,7 +376,7 @@ struct DeviceBatchedGemmXdl
...
@@ -373,7 +376,7 @@ struct DeviceBatchedGemmXdl
const
ADataType
*
p_a_grid_
;
const
ADataType
*
p_a_grid_
;
const
BDataType
*
p_b_grid_
;
const
BDataType
*
p_b_grid_
;
CDataType
*
p_c_grid_
;
CDataType
*
p_c_grid_
;
index_t
Batch
Count
_
;
index_t
Batch_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
AGridDesc_K0_M_K1
a_grid_desc_k0_m_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
BGridDesc_K0_N_K1
b_grid_desc_k0_n_k1_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
CGridDesc_M_N
c_grid_desc_m_n_
;
...
@@ -417,7 +420,7 @@ struct DeviceBatchedGemmXdl
...
@@ -417,7 +420,7 @@ struct DeviceBatchedGemmXdl
}
}
const
index_t
grid_size
=
const
index_t
grid_size
=
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch
Count
_
;
arg
.
block_2_ctile_map_
.
CalculateGridSize
(
arg
.
c_grid_desc_m_n_
)
*
arg
.
Batch_
;
const
auto
K
=
const
auto
K
=
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I0
)
*
arg
.
a_grid_desc_k0_m_k1_
.
GetLength
(
I2
);
...
@@ -448,7 +451,7 @@ struct DeviceBatchedGemmXdl
...
@@ -448,7 +451,7 @@ struct DeviceBatchedGemmXdl
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
Batch
Count
_
,
arg
.
Batch_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
...
@@ -482,7 +485,7 @@ struct DeviceBatchedGemmXdl
...
@@ -482,7 +485,7 @@ struct DeviceBatchedGemmXdl
arg
.
p_a_grid_
,
arg
.
p_a_grid_
,
arg
.
p_b_grid_
,
arg
.
p_b_grid_
,
arg
.
p_c_grid_
,
arg
.
p_c_grid_
,
arg
.
Batch
Count
_
,
arg
.
Batch_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
a_grid_desc_k0_m_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
b_grid_desc_k0_n_k1_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
arg
.
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_
,
...
@@ -536,7 +539,7 @@ struct DeviceBatchedGemmXdl
...
@@ -536,7 +539,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
Count
)
index_t
Batch
)
{
{
return
Argument
{
p_a
,
return
Argument
{
p_a
,
p_b
,
p_b
,
...
@@ -552,7 +555,7 @@ struct DeviceBatchedGemmXdl
...
@@ -552,7 +555,7 @@ struct DeviceBatchedGemmXdl
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
Batch
Count
};
Batch
};
}
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
static
auto
MakeInvoker
()
{
return
Invoker
{};
}
...
@@ -570,7 +573,7 @@ struct DeviceBatchedGemmXdl
...
@@ -570,7 +573,7 @@ struct DeviceBatchedGemmXdl
AElementwiseOperation
a_element_op
,
AElementwiseOperation
a_element_op
,
BElementwiseOperation
b_element_op
,
BElementwiseOperation
b_element_op
,
CElementwiseOperation
c_element_op
,
CElementwiseOperation
c_element_op
,
index_t
Batch
Count
)
override
index_t
Batch
)
override
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_a
),
static_cast
<
const
BDataType
*>
(
p_b
),
static_cast
<
const
BDataType
*>
(
p_b
),
...
@@ -586,7 +589,7 @@ struct DeviceBatchedGemmXdl
...
@@ -586,7 +589,7 @@ struct DeviceBatchedGemmXdl
a_element_op
,
a_element_op
,
b_element_op
,
b_element_op
,
c_element_op
,
c_element_op
,
Batch
Count
);
Batch
);
}
}
// polymorphic
// polymorphic
...
...
include/ck/tensor_operation/gpu/device/device_binary_elementwise.hpp
View file @
0b11569f
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2022, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include <iostream>
#include <iostream>
...
@@ -6,6 +9,7 @@
...
@@ -6,6 +9,7 @@
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/device_prop.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/device_utility/kernel_launch.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/tensor_operation/gpu/device/device_elementwise.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_binary_elementwise_1d.hpp"
namespace
ck
{
namespace
ck
{
...
@@ -22,7 +26,7 @@ template <typename ADataType,
...
@@ -22,7 +26,7 @@ template <typename ADataType,
index_t
AScalarPerVector
,
index_t
AScalarPerVector
,
index_t
BScalarPerVector
,
index_t
BScalarPerVector
,
index_t
CScalarPerVector
>
index_t
CScalarPerVector
>
struct
DeviceBinaryElementwise
:
public
BaseOpera
tor
struct
DeviceBinaryElementwise
:
public
DeviceElementwise
<
2
,
1
,
NDim
,
ElementwiseFunc
tor
>
{
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I0
=
Number
<
0
>
{};
...
@@ -195,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
...
@@ -195,27 +199,30 @@ struct DeviceBinaryElementwise : public BaseOperator
return
true
;
return
true
;
};
};
std
::
unique_ptr
<
BaseArgument
>
MakeArgumentPointer
(
const
void
*
p_a
,
virtual
std
::
unique_ptr
<
BaseArgument
>
const
void
*
p_b
,
MakeArgumentPointer
(
std
::
array
<
const
void
*
,
2
>
p_inputs
,
void
*
p_c
,
std
::
array
<
void
*
,
1
>
p_outputs
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
lengths
,
std
::
vector
<
index_t
>
a_strides
,
std
::
vector
<
std
::
vector
<
index_t
>>
input_strides
,
std
::
vector
<
index_t
>
b_strides
,
std
::
vector
<
std
::
vector
<
index_t
>>
output_strides
,
std
::
vector
<
index_t
>
c_strides
,
ElementwiseFunctor
functor
)
override
ElementwiseFunctor
functor
)
{
{
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
a
),
return
std
::
make_unique
<
Argument
>
(
static_cast
<
const
ADataType
*>
(
p_
inputs
[
0
]
),
static_cast
<
const
BDataType
*>
(
p_
b
),
static_cast
<
const
BDataType
*>
(
p_
inputs
[
1
]
),
static_cast
<
CDataType
*>
(
p_
c
),
static_cast
<
CDataType
*>
(
p_
outputs
[
0
]
),
lengths
,
lengths
,
a
_strides
,
input
_strides
[
0
]
,
b
_strides
,
input
_strides
[
1
]
,
c
_strides
,
output
_strides
[
0
]
,
functor
);
functor
);
}
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
{
return
std
::
make_unique
<
Invoker
>
();
}
std
::
unique_ptr
<
BaseInvoker
>
MakeInvokerPointer
()
override
{
return
std
::
make_unique
<
Invoker
>
();
}
// polymorphic
std
::
string
GetTypeString
()
const
override
std
::
string
GetTypeString
()
const
override
{
{
auto
str
=
std
::
stringstream
();
auto
str
=
std
::
stringstream
();
...
@@ -223,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
...
@@ -223,7 +230,11 @@ struct DeviceBinaryElementwise : public BaseOperator
// clang-format off
// clang-format off
str
<<
"DeviceBinaryElementwise"
str
<<
"DeviceBinaryElementwise"
<<
"<"
<<
"<"
<<
"NDim = "
<<
NDim
<<
"MPerThread = "
<<
MPerThread
<<
"MPerThread = "
<<
MPerThread
<<
"AScalarPerVector = "
<<
AScalarPerVector
<<
"BScalarPerVector = "
<<
BScalarPerVector
<<
"CScalarPerVector = "
<<
CScalarPerVector
<<
">"
;
<<
">"
;
// clang-format on
// clang-format on
...
...
Prev
1
2
3
4
5
6
7
8
9
…
28
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment