Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
6bbb94f4
"git@developer.sourcefind.cn:gaoqiong/migraphx.git" did not exist on "fd877fabf21caea741b8ad6a63e96bfb2fe395e5"
Unverified
Commit
6bbb94f4
authored
Jun 26, 2024
by
Rostyslav Geyyer
Committed by
GitHub
Jun 26, 2024
Browse files
Merge branch 'develop' into lwpck-1815
parents
4c850c90
a32b1bc6
Changes
107
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1904 additions
and
73 deletions
+1904
-73
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
+11
-14
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
+11
-14
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+7
-9
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+10
-13
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
..._operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
+4
-5
include/ck/tensor_operation/gpu/device/helper.hpp
include/ck/tensor_operation/gpu/device/helper.hpp
+359
-0
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...gen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+781
-0
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
+13
-0
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
...or_operation/gpu/element/unary_element_wise_operation.hpp
+18
-0
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
...eration/operator_transform/transform_conv_fwd_to_gemm.hpp
+564
-0
include/ck/utility/amd_buffer_addressing.hpp
include/ck/utility/amd_buffer_addressing.hpp
+2
-1
include/ck/utility/array.hpp
include/ck/utility/array.hpp
+2
-0
include/ck_tile/core.hpp
include/ck_tile/core.hpp
+1
-0
include/ck_tile/core/arch/amd_buffer_addressing.hpp
include/ck_tile/core/arch/amd_buffer_addressing.hpp
+8
-2
include/ck_tile/core/numeric/null_type.hpp
include/ck_tile/core/numeric/null_type.hpp
+13
-0
include/ck_tile/host.hpp
include/ck_tile/host.hpp
+1
-0
include/ck_tile/host/check_err.hpp
include/ck_tile/host/check_err.hpp
+15
-10
include/ck_tile/host/reference/reference_layernorm2d.hpp
include/ck_tile/host/reference/reference_layernorm2d.hpp
+69
-0
include/ck_tile/host/timer.hpp
include/ck_tile/host/timer.hpp
+5
-5
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+10
-0
No files found.
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
@@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
@@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
@@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
...
@@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
@@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
a_thread_copy_
.
Run
(
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
a_block_desc_m0_m1_m2_k
,
...
@@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
a_thread_copy_
.
Run
(
a_thread_copy_
.
Run
(
...
@@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
...
@@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -352,10 +352,9 @@ struct BlockwiseGemmWMMA
...
@@ -352,10 +352,9 @@ struct BlockwiseGemmWMMA
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
wmma_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
@@ -411,10 +410,9 @@ struct BlockwiseGemmWMMA
...
@@ -411,10 +410,9 @@ struct BlockwiseGemmWMMA
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
wmma_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
@@ -530,10 +529,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
...
@@ -530,10 +529,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// TODO: insert setprio in more precise manner since we
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
// could have more than >1 MFMA instructions in single call
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -963,10 +961,9 @@ struct BlockwiseGemmXdlops_v2
...
@@ -963,10 +961,9 @@ struct BlockwiseGemmXdlops_v2
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
...
@@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
constexpr
index_t
c_offset
=
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
});
});
});
...
...
include/ck/tensor_operation/gpu/device/helper.hpp
0 → 100644
View file @
6bbb94f4
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include <fstream>
#include <variant>
// functions to return the corresponding structs based on generated template parameters
using
layouts
=
std
::
variant
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
ck
::
tensor_layout
::
convolution
::
GNHWK
,
ck
::
tensor_layout
::
convolution
::
NHWGK
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWGK
>
;
// return the layout type: currently this is the only type supported in MIOpen
auto
layout_type
(
std
::
string
type
)
{
if
(
type
==
"ck::tensor_layout::convolution::NHWGK"
)
{
return
ck
::
tensor_layout
::
convolution
::
NHWGK
{};
}
throw
std
::
runtime_error
(
"Incorrect layout"
);
}
// return the right gemm spec based on the generated template parameters
ck
::
tensor_operation
::
device
::
GemmSpecialization
gemm_type
(
std
::
string
type
)
{
if
(
type
==
"ck::tensor_operation::device::GemmSpecialization::Default"
)
{
return
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
}
if
(
type
==
"ck::tensor_operation::device::GemmSpecialization::MNKPadding"
)
{
return
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
}
throw
std
::
runtime_error
(
"Incorrect gemm spec: "
+
type
);
}
// return the type of convolution
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
conv_type
(
std
::
string
type
)
{
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
}
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
;
}
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
;
}
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
;
}
throw
std
::
runtime_error
(
"Incorrect conv spec: "
+
type
);
}
// Function to call on MatrixPadder via a wrapper struct
// NOTE: CK only uses MNKPadding for forward convolution
template
<
typename
CDesc_MRaw_NRaw
>
auto
pad
(
ck
::
index_t
mpb
,
ck
::
index_t
npb
,
ck
::
index_t
kpb
,
ck
::
tensor_operation
::
device
::
GemmSpecialization
gemm
,
CDesc_MRaw_NRaw
conv
)
{
if
(
gemm
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
ck
::
tensor_operation
::
device
::
MatrixPadder
<
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
,
ck
::
index_t
,
ck
::
index_t
,
ck
::
index_t
>
a
;
a
.
MPerTile_
=
mpb
;
a
.
NPerTile_
=
npb
;
a
.
KPerTile_
=
kpb
;
auto
tmp
=
grid_desc
(
a
,
conv
);
return
tmp
;
}
throw
std
::
runtime_error
(
"Incorrect template parameters, check gemm spec"
);
}
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
// dims
// FIXME: add a way to properly pass in the layout
auto
transform_conv
(
ck
::
index_t
num_dim
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
spec
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_strides
)
{
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
throw
std
::
runtime_error
(
"Incorrect conv spec"
);
}
auto
transform_conv_3d
(
ck
::
index_t
num_dim
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
spec
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_strides
)
{
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
throw
std
::
runtime_error
(
"Incorrect conv spec"
);
}
auto
transform_conv_1d
(
ck
::
index_t
num_dim
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
spec
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_strides
)
{
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
throw
std
::
runtime_error
(
"Incorrect dims or conv spec"
);
}
template
<
typename
CGridDesc_M_N
>
auto
block_2_etile
(
ck
::
index_t
m_per_block
,
ck
::
index_t
n_per_block
,
CGridDesc_M_N
matrix_padder
)
{
if
(
m_per_block
==
32
&&
n_per_block
==
64
)
{
auto
b2e
=
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
32
,
64
,
CGridDesc_M_N
>
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
32
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
32
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
64
&&
n_per_block
==
32
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
64
,
32
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
64
&&
n_per_block
==
64
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
64
,
64
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
64
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
64
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
32
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
32
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
64
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
64
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
256
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
256
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
256
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
256
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
throw
std
::
runtime_error
(
"Incorrect template parameters"
);
}
// wrapper functions by dims to get grid size - uses above 3 functions
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
auto
get_launch_params_1d
(
ck
::
host
::
Solution
solution
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_strides
)
{
auto
num_dim
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NumDim"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NPerBlock"
);
auto
k_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"KPerBlock"
);
auto
GemmType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"GemmSpecialization"
);
auto
ConvType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"ConvSpecialization"
);
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
=
gemm_type
(
GemmType
);
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
ConvSpec
=
conv_type
(
ConvType
);
auto
conv_to_gemm_transformer
=
transform_conv_1d
(
num_dim
,
ConvSpec
,
out_lengths
,
out_strides
);
auto
matrix_padder
=
pad
(
m_per_block
,
n_per_block
,
k_per_block
,
GemmSpec
,
conv_to_gemm_transformer
);
auto
b2e
=
block_2_etile
(
m_per_block
,
n_per_block
,
matrix_padder
);
return
b2e
;
}
auto
get_launch_params
(
ck
::
host
::
Solution
solution
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_strides
)
{
auto
num_dim
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NumDim"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NPerBlock"
);
auto
k_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"KPerBlock"
);
auto
GemmType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"GemmSpecialization"
);
auto
ConvType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"ConvSpecialization"
);
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
=
gemm_type
(
GemmType
);
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
ConvSpec
=
conv_type
(
ConvType
);
auto
conv_to_gemm_transformer
=
transform_conv
(
num_dim
,
ConvSpec
,
out_lengths
,
out_strides
);
auto
matrix_padder
=
pad
(
m_per_block
,
n_per_block
,
k_per_block
,
GemmSpec
,
conv_to_gemm_transformer
);
auto
b2e
=
block_2_etile
(
m_per_block
,
n_per_block
,
matrix_padder
);
return
b2e
;
}
auto
get_launch_params_3d
(
ck
::
host
::
Solution
solution
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_strides
)
{
auto
num_dim
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NumDim"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NPerBlock"
);
auto
k_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"KPerBlock"
);
auto
GemmType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"GemmSpecialization"
);
auto
ConvType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"ConvSpecialization"
);
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
=
gemm_type
(
GemmType
);
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
ConvSpec
=
conv_type
(
ConvType
);
auto
conv_to_gemm_transformer
=
transform_conv_3d
(
num_dim
,
ConvSpec
,
out_lengths
,
out_strides
);
auto
matrix_padder
=
pad
(
m_per_block
,
n_per_block
,
k_per_block
,
GemmSpec
,
conv_to_gemm_transformer
);
auto
b2e
=
block_2_etile
(
m_per_block
,
n_per_block
,
matrix_padder
);
return
b2e
;
}
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
// tuples if multi AB, pointers if no
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
bool
isMultiA
,
bool
isMultiB
>
__device__
void
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
const
auto
&
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
DsPointer
p_ds_grid_grp
;
static
constexpr
index_t
NumDTensor
=
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_batch_offset
[
i
];
});
if
constexpr
(
isMultiA
||
isMultiB
)
{
AsPointer
p_as_grid_grp
;
BsPointer
p_bs_grid_grp
;
const
auto
&
as_batch_offset
=
compute_ptr_offset_of_batch
.
GetAsPtrOffset
(
g_idx
);
static
constexpr
index_t
NumATensor
=
AGridDesc_AK0_M_AK1
::
Size
();
static_for
<
0
,
NumATensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_as_grid_grp
(
i
)
=
p_as_grid
[
i
]
+
as_batch_offset
[
i
];
});
const
auto
&
bs_batch_offset
=
compute_ptr_offset_of_batch
.
GetBsPtrOffset
(
g_idx
);
static
constexpr
index_t
NumBTensor
=
BGridDesc_BK0_N_BK1
::
Size
();
static_for
<
0
,
NumBTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_bs_grid_grp
(
i
)
=
p_bs_grid
[
i
]
+
bs_batch_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid_grp
,
p_bs_grid_grp
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
}
else
{
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
+
a_batch_offset
,
p_bs_grid
+
b_batch_offset
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
}
#else
ignore
=
p_as_grid
;
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
block_2_ctile_map
;
#endif
}
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
// tuples if multi AB, pointers if no
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
bool
isMultiA
,
bool
isMultiB
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
AsPointer
,
// tuples if multi AB, pointers if no
BsPointer
,
DsPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfBatch
,
HasMainKBlockLoop
,
isMultiA
,
isMultiB
>
(
p_as_grid
,
p_bs_grid
,
p_ds_grid
,
*
p_e_grid
,
a_element_op
,
b_element_op
,
cde_element_op
,
batch_count
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
,
compute_ptr_offset_of_batch
);
}
}
// namespace
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimentions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
typename
ComputeDataType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
()),
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()
>
struct
CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
:
public
DeviceGroupedConvFwdMultipleABD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputeDataType
>
{
using
DeviceOp
=
CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
;
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
__host__
__device__
static
auto
MakeAGridDescriptor_M_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
return
in_gemmm_gemmk_desc
;
}
template
<
typename
BLay
>
__host__
__device__
static
auto
MakeBGridDescriptor_N_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
return
wei_gemmn_gemmk_desc
;
}
template
<
typename
ELay
>
__host__
__device__
static
auto
MakeEGridDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
return
out_gemmm_gemmn_desc
;
}
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
using
GemmADataType
=
std
::
conditional_t
<!
isMultiA
&&
isMultiB
,
Tuple
<
ADataType
>
,
ADataType
>
;
using
GemmBDataType
=
std
::
conditional_t
<!
isMultiB
&&
isMultiA
,
Tuple
<
BDataType
>
,
BDataType
>
;
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
// Use appropriate gridwise gemm
using
GridwiseGemm
=
std
::
conditional_t
<
isMultiA
||
isMultiB
,
GridwiseGemmMultipleABD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>
,
GridwiseGemmMultipleD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>>
;
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
using
APointers
=
std
::
conditional_t
<
isMultiA
,
ck
::
Array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
ck
::
Array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
// in initializer list what is required for single const pointer).
using
AGridPointer
=
remove_cvref_t
<
decltype
(
GetAGridPointer
<
isMultiA
||
isMultiB
,
GridwiseGemm
,
ADataType
>
())
>
;
using
BGridPointer
=
remove_cvref_t
<
decltype
(
GetBGridPointer
<
isMultiA
||
isMultiB
,
GridwiseGemm
,
BDataType
>
())
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
{
__device__
__host__
Argument
(
APointers
p_as
,
BPointers
p_bs
,
const
ck
::
Array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_as_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// A/B/E Batch Stride
if
constexpr
(
isMultiA
||
isMultiB
)
{
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
0
];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmADataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if
constexpr
(
isMultiA
)
{
// p_as is tuple
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
[
i
.
value
]);
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
}
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_
.
BatchStrideB_
(
i
)
=
b_g_k_c_xs_strides
[
0
];
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if
constexpr
(
isMultiB
)
{
// p_bs is tuple
p_bs_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_bs
[
i
.
value
]);
}
else
{
// if MultiA and not MultiB then p_bs is single pointer
p_bs_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_bs
);
}
});
}
else
{
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
// p_as and p_bs are pointers
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_bs_grid_
(
I0
)
=
static_cast
<
const
BDataType
*>
(
p_bs
);
}
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
]);
});
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
// populate desc for Ds/E
if
constexpr
(
isMultiA
||
isMultiB
)
{
const
auto
as_grid_desc_ak0_m_ak1
=
generate_tuple
([
&
](
auto
)
{
return
a_grid_desc_m_k_
;
},
Number
<
NumATensor
>
{});
const
auto
bs_grid_desc_bk0_n_bk1
=
generate_tuple
([
&
](
auto
)
{
return
b_grid_desc_n_k_
;
},
Number
<
NumBTensor
>
{});
if
(
GridwiseGemm
::
CheckValidity
(
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
}
else
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
}
}
// private:
// pointers (tuple if multi AB, pointer if no)
AGridPointer
p_as_grid_
;
BGridPointer
p_bs_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
static
__device__
__host__
auto
MakeArgument
(
APointers
p_as
,
BPointers
p_bs
,
const
ck
::
Array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/matrix_padder.hpp
View file @
6bbb94f4
...
@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
...
@@ -180,6 +180,19 @@ struct MatrixPadder : public GemmPadder<GemmSpec, MPerTileType, NPerTileType, KP
{
{
};
};
// function to take in a struct of type MatrixPadder and call the appropriate function to get
// the output descriptor at runtime for codegen
template
<
GemmSpecialization
GemmSpec
,
typename
MPerTileType
,
typename
NPerTileType
,
typename
KPerTileType
,
typename
CDesc_MRaw_NRaw
>
auto
grid_desc
(
MatrixPadder
<
GemmSpec
,
MPerTileType
,
NPerTileType
,
KPerTileType
>
matrix_padder
,
CDesc_MRaw_NRaw
conv_desc
)
{
auto
res
=
matrix_padder
.
PadCDescriptor_M_N
(
conv_desc
);
return
res
;
}
// M/N/KPerTileType could be index_t or Number<>
// M/N/KPerTileType could be index_t or Number<>
template
<
bool
PadM
,
template
<
bool
PadM
,
bool
PadN
,
bool
PadN
,
...
...
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp
View file @
6bbb94f4
...
@@ -961,6 +961,24 @@ struct Elu
...
@@ -961,6 +961,24 @@ struct Elu
const
float
alpha_
;
const
float
alpha_
;
};
};
struct
Logistic
{
Logistic
(
float
alpha
=
1.
f
)
:
alpha_
(
alpha
){};
template
<
typename
T
>
__host__
__device__
void
operator
()(
T
&
y
,
const
T
&
x
)
const
{
static_assert
(
is_same
<
T
,
float
>::
value
||
is_same
<
T
,
double
>::
value
||
is_same
<
T
,
half_t
>::
value
||
is_same
<
T
,
int32_t
>::
value
||
is_same
<
T
,
int8_t
>::
value
,
"Data type is not supported by this operation!"
);
T
casted_alpha
=
type_convert
<
T
>
(
alpha_
);
constexpr
T
one
=
type_convert
<
T
>
(
1
);
y
=
casted_alpha
/
(
one
+
ck
::
math
::
exp
(
-
x
)
*
casted_alpha
);
}
const
float
alpha_
;
};
struct
ConvInvscale
struct
ConvInvscale
{
{
__host__
__device__
ConvInvscale
(
float
scale_in
=
1.
f
,
__host__
__device__
ConvInvscale
(
float
scale_in
=
1.
f
,
...
...
include/ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp
View file @
6bbb94f4
...
@@ -14,6 +14,17 @@
...
@@ -14,6 +14,17 @@
namespace
ck
{
namespace
ck
{
namespace
tensor_operation
{
namespace
tensor_operation
{
// function to be used on device, emulates std::accumulate
template
<
typename
T
,
typename
ForwardIterator
,
typename
Size
>
__host__
__device__
auto
mult_accumulate_n
(
ForwardIterator
first
,
Size
count
,
T
init
)
{
for
(
ForwardIterator
x
=
first
;
x
!=
first
+
count
;
x
++
)
{
init
*=
*
x
;
}
return
init
;
}
template
<
index_t
NDimSpatial
,
device
::
ConvolutionForwardSpecialization
ConvForwardSpecialization
>
template
<
index_t
NDimSpatial
,
device
::
ConvolutionForwardSpecialization
ConvForwardSpecialization
>
struct
TransformConvFwdToGemm
struct
TransformConvFwdToGemm
{
{
...
@@ -607,6 +618,559 @@ struct TransformConvFwdToGemm
...
@@ -607,6 +618,559 @@ struct TransformConvFwdToGemm
return
out_gemmm_gemmn_desc
;
return
out_gemmm_gemmn_desc
;
}
}
// Overloaded functions for hipRTC purposes
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
1
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNWC
>
),
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeADescriptor_M_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
index_t
N
=
a_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Wo
=
c_g_n_k_wos_lengths
[
3
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
0
];
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
// This is different
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
2
+
NDimSpatial
];
const
auto
CStride
=
I1
;
const
auto
in_gemmm_gemmk_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
C
),
make_tuple
(
WiStride
,
CStride
));
return
in_gemmm_gemmk_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
3
];
const
auto
CStride
=
I1
;
const
auto
in_n_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Wi
,
C
),
make_tuple
(
NStride
,
WiStride
,
CStride
));
const
auto
in_n_wo_c_desc
=
transform_tensor_descriptor
(
in_n_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
else
{
const
index_t
X
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
0
];
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
3
];
const
auto
CStride
=
I1
;
const
auto
in_n_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Wi
,
C
),
make_tuple
(
NStride
,
WiStride
,
CStride
));
const
auto
in_n_wip_c_desc
=
transform_tensor_descriptor
(
in_n_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}));
const
auto
in_n_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Wo
)),
make_merge_transform
(
make_tuple
(
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
>
{},
Sequence
<
1
,
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
}
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
2
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NHW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NHWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNHWC
>
),
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeADescriptor_M_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
index_t
N
=
a_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
Hi
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
4
];
const
index_t
Ho
=
c_g_n_k_wos_lengths
[
3
];
const
index_t
Wo
=
c_g_n_k_wos_lengths
[
4
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
1
];
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
NHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
// This is different
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
2
+
NDimSpatial
];
const
auto
CStride
=
I1
;
const
auto
in_gemmm_gemmk_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
C
),
make_tuple
(
WiStride
,
CStride
));
return
in_gemmm_gemmk_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
HiStride
=
a_g_n_c_wis_strides
[
3
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
4
];
const
auto
CStride
=
I1
;
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
HiStride
,
WiStride
,
CStride
));
const
auto
in_n_ho_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_ho_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
else
{
const
index_t
Y
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
4
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
1
];
const
index_t
InLeftPadH
=
input_left_pads
[
0
];
const
index_t
InLeftPadW
=
input_left_pads
[
1
];
const
index_t
InRightPadH
=
input_right_pads
[
0
];
const
index_t
InRightPadW
=
input_right_pads
[
1
];
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
HiStride
=
a_g_n_c_wis_strides
[
3
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
4
];
const
auto
CStride
=
I1
;
const
auto
in_n_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
HiStride
,
WiStride
,
CStride
));
const
auto
in_n_hip_wip_c_desc
=
transform_tensor_descriptor
(
in_n_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}));
const
auto
in_n_y_ho_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_y_ho_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
,
4
>
{},
Sequence
<
1
,
3
,
5
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
}
template
<
typename
ALayout
,
typename
std
::
enable_if
<
NDimSpatial
==
3
&&
(
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
G_NDHW_C
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
NDHWGC
>
||
is_same_v
<
ALayout
,
tensor_layout
::
convolution
::
GNDHWC
>
),
bool
>::
type
=
false
>
static
auto
MakeADescriptor_M_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
index_t
N
=
a_g_n_c_wis_lengths
[
1
];
const
index_t
C
=
a_g_n_c_wis_lengths
[
2
];
const
index_t
Di
=
a_g_n_c_wis_lengths
[
3
];
const
index_t
Hi
=
a_g_n_c_wis_lengths
[
4
];
const
index_t
Wi
=
a_g_n_c_wis_lengths
[
5
];
const
index_t
Do
=
c_g_n_k_wos_lengths
[
3
];
const
index_t
Ho
=
c_g_n_k_wos_lengths
[
4
];
const
index_t
Wo
=
c_g_n_k_wos_lengths
[
5
];
const
index_t
ConvStrideD
=
conv_filter_strides
[
0
];
const
index_t
ConvStrideH
=
conv_filter_strides
[
1
];
const
index_t
ConvStrideW
=
conv_filter_strides
[
2
];
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
const
index_t
NDoHoWo
=
N
*
ck
::
accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
,
std
::
multiplies
<>
());
// This is different
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
2
+
NDimSpatial
];
const
auto
CStride
=
I1
;
const
auto
in_gemmm_gemmk_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NDoHoWo
,
C
),
make_tuple
(
WiStride
,
CStride
));
return
in_gemmm_gemmk_desc
;
}
else
if
constexpr
(
ConvForwardSpecialization
==
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
DiStride
=
a_g_n_c_wis_strides
[
3
];
const
index_t
HiStride
=
a_g_n_c_wis_strides
[
4
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
5
];
const
auto
CStride
=
I1
;
const
auto
in_n_di_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
DiStride
,
HiStride
,
WiStride
,
CStride
));
const
auto
in_n_do_ho_wo_c_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Do
),
make_tuple
(
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Ho
),
make_tuple
(
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
Wo
),
make_tuple
(
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_do_ho_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
,
1
,
2
,
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
else
{
const
index_t
Z
=
b_g_k_c_xs_lengths
[
3
];
const
index_t
Y
=
b_g_k_c_xs_lengths
[
4
];
const
index_t
X
=
b_g_k_c_xs_lengths
[
5
];
const
index_t
ConvDilationD
=
conv_filter_dilations
[
0
];
const
index_t
ConvDilationH
=
conv_filter_dilations
[
1
];
const
index_t
ConvDilationW
=
conv_filter_dilations
[
2
];
const
index_t
InLeftPadD
=
input_left_pads
[
0
];
const
index_t
InLeftPadH
=
input_left_pads
[
1
];
const
index_t
InLeftPadW
=
input_left_pads
[
2
];
const
index_t
InRightPadD
=
input_right_pads
[
0
];
const
index_t
InRightPadH
=
input_right_pads
[
1
];
const
index_t
InRightPadW
=
input_right_pads
[
2
];
// This is different
const
index_t
NStride
=
a_g_n_c_wis_strides
[
1
];
const
index_t
DiStride
=
a_g_n_c_wis_strides
[
3
];
const
index_t
HiStride
=
a_g_n_c_wis_strides
[
4
];
const
index_t
WiStride
=
a_g_n_c_wis_strides
[
5
];
const
auto
CStride
=
I1
;
const
auto
in_n_di_hi_wi_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
N
,
Di
,
Hi
,
Wi
,
C
),
make_tuple
(
NStride
,
DiStride
,
HiStride
,
WiStride
,
CStride
));
const
auto
in_n_hip_wip_c_desc
=
transform_tensor_descriptor
(
in_n_di_hi_wi_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_pad_transform
(
Di
,
InLeftPadD
,
InRightPadD
),
make_pad_transform
(
Hi
,
InLeftPadH
,
InRightPadH
),
make_pad_transform
(
Wi
,
InLeftPadW
,
InRightPadW
),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}));
const
auto
in_n_z_do_y_ho_x_wo_c_desc
=
transform_tensor_descriptor
(
in_n_hip_wip_c_desc
,
make_tuple
(
make_pass_through_transform
(
N
),
make_embed_transform
(
make_tuple
(
Z
,
Do
),
make_tuple
(
ConvDilationD
,
ConvStrideD
)),
make_embed_transform
(
make_tuple
(
Y
,
Ho
),
make_tuple
(
ConvDilationH
,
ConvStrideH
)),
make_embed_transform
(
make_tuple
(
X
,
Wo
),
make_tuple
(
ConvDilationW
,
ConvStrideW
)),
make_pass_through_transform
(
C
)),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{},
Sequence
<
2
>
{},
Sequence
<
3
>
{},
Sequence
<
4
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{},
Sequence
<
3
,
4
>
{},
Sequence
<
5
,
6
>
{},
Sequence
<
7
>
{}));
const
auto
in_gemmm_gemmk_desc
=
transform_tensor_descriptor
(
in_n_z_do_y_ho_x_wo_c_desc
,
make_tuple
(
make_merge_transform
(
make_tuple
(
N
,
Do
,
Ho
,
Wo
)),
make_merge_transform
(
make_tuple
(
Z
,
Y
,
X
,
C
))),
make_tuple
(
Sequence
<
0
,
2
,
4
,
6
>
{},
Sequence
<
1
,
3
,
5
,
7
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
in_gemmm_gemmk_desc
;
}
}
template
<
typename
BLayout
,
typename
std
::
enable_if
<
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKXC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKYXC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
GKZYXC
>
,
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeBDescriptor_N_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* b_g_k_c_xs_strides */
)
{
const
index_t
K
=
b_g_k_c_xs_lengths
[
1
];
const
index_t
C
=
b_g_k_c_xs_lengths
[
2
];
const
index_t
YX
=
mult_accumulate_n
<
index_t
>
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
);
const
auto
wei_gemmn_gemmk_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
K
,
YX
*
C
));
return
wei_gemmn_gemmk_desc
;
}
template
<
typename
BLayout
,
typename
std
::
enable_if
<
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
G_K_X_C
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
G_K_YX_C
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
G_K_ZYX_C
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
KXGC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
KYXGC
>
||
is_same_v
<
BLayout
,
tensor_layout
::
convolution
::
KZYXGC
>
,
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeBDescriptor_N_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
const
index_t
K
=
b_g_k_c_xs_lengths
[
1
];
const
index_t
C
=
b_g_k_c_xs_lengths
[
2
];
const
index_t
YX
=
mult_accumulate_n
<
index_t
>
(
b_g_k_c_xs_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
);
const
index_t
KStride
=
b_g_k_c_xs_strides
[
1
];
const
index_t
XStride
=
b_g_k_c_xs_strides
[
2
+
NDimSpatial
];
const
auto
CStride
=
I1
;
const
auto
wei_k_yx_c_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
K
,
YX
,
C
),
make_tuple
(
KStride
,
XStride
,
CStride
));
const
auto
wei_gemmn_gemmk_desc
=
transform_tensor_descriptor
(
wei_k_yx_c_desc
,
make_tuple
(
make_pass_through_transform
(
K
),
make_merge_transform
(
make_tuple
(
YX
,
C
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}));
return
wei_gemmn_gemmk_desc
;
}
template
<
typename
CLayout
,
typename
std
::
enable_if
<
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNWK
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNHWK
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
GNDHWK
>
,
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeCDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
/* c_g_n_k_wos_strides */
)
{
const
index_t
N
=
c_g_n_k_wos_lengths
[
1
];
const
index_t
K
=
c_g_n_k_wos_lengths
[
2
];
const
index_t
NHoWo
=
N
*
mult_accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
);
const
auto
out_gemmm_gemmn_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
NHoWo
,
K
));
return
out_gemmm_gemmn_desc
;
}
template
<
typename
CLayout
,
typename
std
::
enable_if
<
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NW_K
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NHW_K
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_NDHW_K
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NWGK
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NHWGK
>
||
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
NDHWGK
>
,
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeCDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_strides
)
{
const
index_t
N
=
c_g_n_k_wos_lengths
[
1
];
const
index_t
K
=
c_g_n_k_wos_lengths
[
2
];
const
auto
KStride
=
I1
;
const
index_t
WoStride
=
c_g_n_k_wos_strides
[
NDimSpatial
+
2
];
const
index_t
NHoWo
=
N
*
mult_accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
);
const
auto
out_gemmm_gemmn_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
K
),
make_tuple
(
WoStride
,
KStride
));
return
out_gemmm_gemmn_desc
;
}
// for output bias
template
<
typename
CLayout
,
typename
std
::
enable_if
<
is_same_v
<
CLayout
,
tensor_layout
::
convolution
::
G_K
>,
bool
>::
type
=
false
>
__host__
__device__
static
auto
MakeCDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
c_g_n_k_wos_strides
)
{
const
index_t
N
=
c_g_n_k_wos_lengths
[
1
];
const
index_t
K
=
c_g_n_k_wos_lengths
[
2
];
const
index_t
KStride
=
c_g_n_k_wos_strides
[
2
];
const
index_t
NHoWo
=
N
*
mult_accumulate_n
<
index_t
>
(
c_g_n_k_wos_lengths
.
begin
()
+
3
,
NDimSpatial
,
1
);
const
auto
out_gemmm_gemmn_desc
=
make_naive_tensor_descriptor
(
make_tuple
(
NHoWo
,
K
),
make_tuple
(
I0
,
KStride
));
return
out_gemmm_gemmn_desc
;
}
};
// wrapper class to call member functions on TransformConvToGemm struct at runtime
// TODO: figure out aq way to properly pass in layout as an argument
struct
TransformConv
{
TransformConv
()
{}
template
<
index_t
NDimSpatial
,
device
::
ConvolutionForwardSpecialization
ConvForwardSpecialization
>
auto
transform_func
(
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
out_lengths
,
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
out_strides
,
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
conv_fwd_to_gemm
)
{
if
(
NDimSpatial
==
2
)
{
return
conv_fwd_to_gemm
.
template
MakeCDescriptor_M_N
<
ck
::
tensor_layout
::
convolution
::
NHWGK
>(
out_lengths
,
out_strides
);
}
else
if
(
NDimSpatial
==
3
)
{
return
conv_fwd_to_gemm
.
template
MakeCDescriptor_M_N
<
tensor_layout
::
convolution
::
NDHWGK
>(
out_lengths
,
out_strides
);
}
else
if
(
NDimSpatial
==
1
)
{
return
conv_fwd_to_gemm
.
template
MakeCDescriptor_M_N
<
tensor_layout
::
convolution
::
NWGK
>(
out_lengths
,
out_strides
);
}
}
};
};
}
// namespace tensor_operation
}
// namespace tensor_operation
...
...
include/ck/utility/amd_buffer_addressing.hpp
View file @
6bbb94f4
...
@@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
...
@@ -991,7 +991,8 @@ __device__ void amd_direct_load_global_to_lds(const T* global_base_ptr,
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"v"
(
global_offset_bytes
),
"v"
(
global_offset_bytes
),
"s"
(
src_resource
));
"s"
(
src_resource
)
:
"memory"
);
#else
#else
// LDS pointer must be attributed with the LDS address space.
// LDS pointer must be attributed with the LDS address space.
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
...
...
include/ck/utility/array.hpp
View file @
6bbb94f4
...
@@ -36,6 +36,8 @@ struct Array
...
@@ -36,6 +36,8 @@ struct Array
return
*
this
;
return
*
this
;
}
}
__host__
__device__
constexpr
const
TData
*
begin
()
const
{
return
&
mData
[
0
];
}
__host__
__device__
constexpr
const
TData
*
end
()
const
{
return
&
mData
[
NSize
];
}
};
};
// empty Array
// empty Array
...
...
include/ck_tile/core.hpp
View file @
6bbb94f4
...
@@ -27,6 +27,7 @@
...
@@ -27,6 +27,7 @@
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integer.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/integral_constant.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/math.hpp"
#include "ck_tile/core/numeric/null_type.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/numeric.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/type_convert.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
#include "ck_tile/core/numeric/vector_type.hpp"
...
...
include/ck_tile/core/arch/amd_buffer_addressing.hpp
View file @
6bbb94f4
...
@@ -26,7 +26,12 @@ struct __attribute__((packed)) buffer_resource
...
@@ -26,7 +26,12 @@ struct __attribute__((packed)) buffer_resource
CK_TILE_DEVICE
int32x4_t
make_wave_buffer_resource
(
const
void
*
ptr
,
uint32_t
size
=
0xffffffff
)
CK_TILE_DEVICE
int32x4_t
make_wave_buffer_resource
(
const
void
*
ptr
,
uint32_t
size
=
0xffffffff
)
{
{
buffer_resource
res
{
ptr
,
size
,
CK_TILE_BUFFER_RESOURCE_3RD_DWORD
};
buffer_resource
res
{
ptr
,
size
,
CK_TILE_BUFFER_RESOURCE_3RD_DWORD
};
return
__builtin_bit_cast
(
int32x4_t
,
res
);
int32x4_t
r
=
__builtin_bit_cast
(
int32x4_t
,
res
);
r
.
x
=
__builtin_amdgcn_readfirstlane
(
r
.
x
);
r
.
y
=
__builtin_amdgcn_readfirstlane
(
r
.
y
);
r
.
z
=
__builtin_amdgcn_readfirstlane
(
r
.
z
);
r
.
w
=
__builtin_amdgcn_readfirstlane
(
r
.
w
);
return
r
;
}
}
namespace
impl
{
namespace
impl
{
...
@@ -2104,7 +2109,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
...
@@ -2104,7 +2109,8 @@ CK_TILE_DEVICE void amd_direct_load_global_to_lds(const T* global_base_ptr,
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
asm
volatile
(
"s_mov_b32 m0, %0;
\n\t
"
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"buffer_load_dword %1, %2, 0 offen lds;
\n\t
"
::
"s"
(
lds_ptr_sgpr
),
"v"
(
global_offset_bytes
),
"v"
(
global_offset_bytes
),
"s"
(
src_resource
));
"s"
(
src_resource
)
:
"memory"
);
#else
#else
// LDS pointer must be attributed with the LDS address space.
// LDS pointer must be attributed with the LDS address space.
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
__attribute__
((
address_space
(
3
)))
uint32_t
*
lds_ptr
=
...
...
include/ck_tile/core/numeric/null_type.hpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <stdint.h>
namespace
ck_tile
{
struct
null_type
{
};
}
// namespace ck_tile
include/ck_tile/host.hpp
View file @
6bbb94f4
...
@@ -18,6 +18,7 @@
...
@@ -18,6 +18,7 @@
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_im2col.hpp"
#include "ck_tile/host/reference/reference_layernorm2d.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_reduce.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/reference/reference_softmax.hpp"
#include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/stream_config.hpp"
...
...
include/ck_tile/host/check_err.hpp
View file @
6bbb94f4
...
@@ -56,8 +56,9 @@ check_err(const Range& out,
...
@@ -56,8 +56,9 @@ check_err(const Range& out,
}
}
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
o
==
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
bit_cast
<
uint64_t
>
(
o
)
==
bit_cast
<
uint64_t
>
(
r
));
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
};
};
...
@@ -114,8 +115,9 @@ check_err(const Range& out,
...
@@ -114,8 +115,9 @@ check_err(const Range& out,
}
}
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
o
==
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
bit_cast
<
uint64_t
>
(
o
)
==
bit_cast
<
uint64_t
>
(
r
));
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
};
};
...
@@ -173,8 +175,9 @@ check_err(const Range& out,
...
@@ -173,8 +175,9 @@ check_err(const Range& out,
}
}
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
o
==
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
bit_cast
<
uint64_t
>
(
o
)
==
bit_cast
<
uint64_t
>
(
r
));
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
};
};
...
@@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -285,8 +288,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
}
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
o
==
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
bit_cast
<
uint64_t
>
(
o
)
==
bit_cast
<
uint64_t
>
(
r
));
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
};
};
...
@@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
...
@@ -357,8 +361,9 @@ std::enable_if_t<(std::is_same_v<ranges::range_value_t<Range>, ranges::range_val
}
}
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
auto
is_infinity_error
=
[
=
](
auto
o
,
auto
r
)
{
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
either_not_finite
=
!
std
::
isfinite
(
o
)
||
!
std
::
isfinite
(
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
o
==
r
);
const
bool
both_infinite_and_same
=
std
::
isinf
(
o
)
&&
std
::
isinf
(
r
)
&&
(
bit_cast
<
uint64_t
>
(
o
)
==
bit_cast
<
uint64_t
>
(
r
));
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
return
either_not_finite
&&
!
(
allow_infinity_ref
&&
both_infinite_and_same
);
};
};
...
...
include/ck_tile/host/reference/reference_layernorm2d.hpp
0 → 100644
View file @
6bbb94f4
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"
namespace
ck_tile
{
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
MeanDataType
,
typename
InvStdDataType
>
void
reference_layernorm2d_fwd
(
const
HostTensor
<
XDataType
>&
x_m_n
,
const
HostTensor
<
GammaDataType
>&
gamma_n
,
const
HostTensor
<
BetaDataType
>&
beta_n
,
HostTensor
<
YDataType
>&
y_m_n
,
HostTensor
<
MeanDataType
>&
mean_m
,
HostTensor
<
InvStdDataType
>&
invStd_m
,
ComputeDataType
epsilon
)
{
auto
layernorm2d_fwd_func
=
[
&
](
auto
m
)
{
const
int
N
=
x_m_n
.
mDesc
.
get_lengths
()[
1
];
int
count
=
0
;
ComputeDataType
mean
=
0
;
ComputeDataType
variance
=
0
;
ComputeDataType
divisor
=
0
;
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
++
count
;
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m
,
n
));
ComputeDataType
delta
=
x
-
mean
;
mean
+=
delta
/
count
;
ComputeDataType
delta2
=
x
-
mean
;
variance
+=
delta
*
delta2
;
}
// actual variance
variance
=
variance
/
count
;
divisor
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
1
)
/
ck_tile
::
sqrt
(
variance
+
epsilon
);
if
constexpr
(
!
std
::
is_same_v
<
MeanDataType
,
ck_tile
::
null_type
>
)
mean_m
(
m
)
=
ck_tile
::
type_convert
<
MeanDataType
>
(
mean
);
if
constexpr
(
!
std
::
is_same_v
<
InvStdDataType
,
ck_tile
::
null_type
>
)
invStd_m
(
m
)
=
ck_tile
::
type_convert
<
InvStdDataType
>
(
divisor
);
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m
,
n
));
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
ComputeDataType
beta
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
beta_n
(
n
));
auto
y
=
(
x
-
mean
)
*
divisor
;
y
=
y
*
gamma
+
beta
;
y_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
YDataType
>
(
y
);
}
};
make_ParallelTensorFunctor
(
layernorm2d_fwd_func
,
mean_m
.
mDesc
.
get_lengths
()[
0
])(
std
::
thread
::
hardware_concurrency
());
}
}
// namespace ck_tile
include/ck_tile/host/timer.hpp
View file @
6bbb94f4
...
@@ -27,7 +27,7 @@ struct gpu_timer
...
@@ -27,7 +27,7 @@ struct gpu_timer
CK_TILE_HOST
void
start
(
const
hipStream_t
&
s
)
CK_TILE_HOST
void
start
(
const
hipStream_t
&
s
)
{
{
HIP_CHECK_ERROR
(
hip
Device
Synchronize
());
HIP_CHECK_ERROR
(
hip
Stream
Synchronize
(
s
));
HIP_CHECK_ERROR
(
hipEventRecord
(
start_evt
,
s
));
HIP_CHECK_ERROR
(
hipEventRecord
(
start_evt
,
s
));
}
}
...
@@ -51,15 +51,15 @@ struct gpu_timer
...
@@ -51,15 +51,15 @@ struct gpu_timer
struct
cpu_timer
struct
cpu_timer
{
{
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST
void
start
(
const
hipStream_t
&
)
CK_TILE_HOST
void
start
(
const
hipStream_t
&
s
)
{
{
HIP_CHECK_ERROR
(
hip
Device
Synchronize
());
HIP_CHECK_ERROR
(
hip
Stream
Synchronize
(
s
));
start_tick
=
std
::
chrono
::
high_resolution_clock
::
now
();
start_tick
=
std
::
chrono
::
high_resolution_clock
::
now
();
}
}
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
// torch.utils.benchmark.Timer(), there is a sync inside each timer callback
CK_TILE_HOST
void
stop
(
const
hipStream_t
&
)
CK_TILE_HOST
void
stop
(
const
hipStream_t
&
s
)
{
{
HIP_CHECK_ERROR
(
hip
Device
Synchronize
());
HIP_CHECK_ERROR
(
hip
Stream
Synchronize
(
s
));
stop_tick
=
std
::
chrono
::
high_resolution_clock
::
now
();
stop_tick
=
std
::
chrono
::
high_resolution_clock
::
now
();
}
}
// return in ms
// return in ms
...
...
include/ck_tile/ops/fmha.hpp
View file @
6bbb94f4
...
@@ -10,6 +10,10 @@
...
@@ -10,6 +10,10 @@
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
...
@@ -22,6 +26,12 @@
...
@@ -22,6 +26,12 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
...
...
Prev
1
2
3
4
5
6
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment