Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
5ec6a912
Commit
5ec6a912
authored
Jun 27, 2024
by
Jun Liu
Browse files
Merge branch 'develop' into amd-develop
parents
d39c3f5d
3bb0fe6c
Changes
226
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1856 additions
and
201 deletions
+1856
-201
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+5
-0
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
+4
-4
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
...or_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
+18
-23
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
+11
-12
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
+17
-19
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
+7
-8
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
+13
-16
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
...operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
+11
-14
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+506
-9
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
...e/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
+17
-13
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
..._operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
+4
-5
include/ck/tensor_operation/gpu/device/helper.hpp
include/ck/tensor_operation/gpu/device/helper.hpp
+359
-0
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
...gen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
+781
-0
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+16
-7
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
...ion/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
+4
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+5
-4
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
...operation/gpu/device/impl/device_column_to_image_impl.hpp
+3
-2
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
...ice/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
+17
-25
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
...evice/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
+20
-27
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
...or_operation/gpu/device/impl/device_contraction_utils.hpp
+38
-10
No files found.
include/ck/host_utility/device_prop.hpp
View file @
5ec6a912
...
...
@@ -84,4 +84,9 @@ inline bool is_gfx11_supported()
ck
::
get_device_name
()
==
"gfx1102"
||
ck
::
get_device_name
()
==
"gfx1103"
;
}
inline
bool
is_gfx12_supported
()
{
return
ck
::
get_device_name
()
==
"gfx1200"
||
ck
::
get_device_name
()
==
"gfx1201"
;
}
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_dpp.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -300,9 +300,9 @@ struct BlockwiseGemmDpp_ak0mak1_bk0nbk1_m0n0m1n1m2n2
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
dpp_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
dpp_input_type
>(),
b_thread_vec
.
template
AsType
<
dpp_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
dpp_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
dpp_input_type
>(),
b_thread_vec
.
template
AsType
<
dpp_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -613,7 +613,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -681,7 +681,7 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -749,10 +749,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -808,10 +807,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -840,10 +838,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -901,10 +898,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -939,10 +935,9 @@ struct BlockwiseGemmXdlops_pipeline_v4
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v1.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -144,12 +144,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
...
...
@@ -259,7 +259,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -319,10 +319,9 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -446,12 +445,12 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
static
constexpr
index_t
PrefetchStages
=
1
;
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
...
...
@@ -584,7 +583,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -668,7 +667,7 @@ struct BlockwiseGemmXdlops_pipeline_v1<BlockGemmPipelineScheduler::Interwave,
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v2.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -153,12 +153,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
...
...
@@ -303,7 +303,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -374,7 +374,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -428,10 +428,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -480,10 +479,9 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -646,12 +644,12 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
PrefetchStages
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
PrefetchStages
==
1
)
{
...
...
@@ -821,7 +819,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -914,7 +912,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -990,7 +988,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -1066,7 +1064,7 @@ struct BlockwiseGemmXdlops_pipeline_v2<BlockGemmPipelineScheduler::Interwave,
block_sync_lds
();
__builtin_amdgcn_sched_barrier
(
0
);
}
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v3.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -146,12 +146,12 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
static
constexpr
index_t
PrefillStages
=
1
;
static
constexpr
index_t
GlobalBufferNum
=
1
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
ignore
=
num_loop
;
return
TailNumber
::
Full
;
...
...
@@ -381,7 +381,7 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -440,10 +440,9 @@ struct BlockwiseGemmXdlops_pipeline_v3<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v4.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -147,12 +147,12 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
static
constexpr
index_t
GlobalBufferNum
=
2
;
static
constexpr
index_t
HotloopUnroll
=
2
;
__host__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
__host__
__device__
static
constexpr
bool
BlockHasHotloop
(
index_t
num_loop
)
{
return
num_loop
>
PrefetchStages
;
}
__host__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
__host__
__device__
static
constexpr
TailNumber
BlockLoopTailNum
(
index_t
num_loop
)
{
if
(
num_loop
%
HotloopUnroll
==
1
)
{
...
...
@@ -403,7 +403,7 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -472,10 +472,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -529,10 +528,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -562,10 +560,9 @@ struct BlockwiseGemmXdlops_pipeline_v4<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_pipeline_xdlops_v5.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -444,7 +444,7 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
...
...
@@ -513,10 +513,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
a_thread_copy_
.
Run
(
a_block_desc_m0_m1_m2_k
,
...
...
@@ -564,10 +563,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
a_thread_copy_
.
Run
(
...
...
@@ -607,10 +605,9 @@ struct BlockwiseGemmXdlops_pipeline_v5<BlockGemmPipelineScheduler::Intrawave,
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -13,6 +13,504 @@
namespace
ck
{
#ifdef __gfx12__
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
typename
FloatAcc
,
typename
ABlockDesc
,
typename
BBlockDesc
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
MPerWMMA
,
index_t
NPerWMMA
,
index_t
MRepeat
,
index_t
NRepeat
,
index_t
KPack
,
bool
AEnableLds
=
true
,
bool
BEnableLds
=
true
,
bool
TransposeC
=
false
>
/* Option: Read from LDS, big buffer hold all threads required data
* Source
* A: K0PerBlock x MPerBlock x K1
* B: K0PerBlock x NPerBlock x K1
* Destination
* C, non-transpose
* thread level: MRepeat x NRepeat x MAccVgprs
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
* KPACK == WMMA_K = 16
*
* Option: Read from VMEM, small buffer hold each thread own required data (Skip LDS)
* Source:
* A(if skip LDS): MRepeat x KPack
* B(if skip LDS): NRepeat x KPack
* Destination
* C, non-transpose
* block level: MRepeat x MWave x MSubGroup x NRepeat x NWave x NThreadPerSubGroup x MAccVgprs
*/
struct
BlockwiseGemmWMMA
{
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
I4
=
Number
<
4
>
{};
static
constexpr
auto
I5
=
Number
<
5
>
{};
static
constexpr
auto
WmmaK
=
Number
<
16
>
{};
using
ThisThreadBlock
=
ThisThreadBlock
<
BlockSize
>
;
// Hardcode of WaveSize, since current HIP Runtime(5.4.0-10984) could not return correct one.
static
constexpr
index_t
WaveSize
=
32
;
// When use LDS, each Row(16 consecutive lanes) read whole data from source buffer
// When not use LDS, each Row read half of whole data from source buffer, exchange the data via
// permutation
static
constexpr
index_t
A_KRow
=
2
;
static
constexpr
index_t
B_KRow
=
2
;
static
constexpr
index_t
A_K1
=
ABlockDesc
{}.
GetLength
(
I5
);
static
constexpr
index_t
B_K1
=
BBlockDesc
{}.
GetLength
(
I5
);
static
constexpr
auto
wmma_gemm
=
WmmaGemm
<
FloatA
,
FloatB
,
FloatAcc
,
MPerWMMA
,
NPerWMMA
,
KPack
,
TransposeC
>
{};
static
constexpr
index_t
MWaves
=
MPerBlock
/
(
MRepeat
*
MPerWMMA
);
static
constexpr
index_t
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWMMA
);
StaticBufferTupleOfVector
<
AddressSpaceEnum
::
Vgpr
,
FloatAcc
,
MRepeat
*
NRepeat
,
wmma_gemm
.
GetRegSizePerWmma
(),
true
>
c_thread_buf_
;
__host__
__device__
constexpr
auto
&
GetCThreadBuffer
()
{
return
c_thread_buf_
;
}
__device__
static
auto
GetWaveIdx
()
{
const
index_t
thread_id
=
ThisThreadBlock
::
GetThreadId
();
constexpr
auto
threadid_to_wave_idx_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_merge_transform
(
make_tuple
(
MWaves
,
NWaves
,
WaveSize
))),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}),
make_tuple
(
Sequence
<
0
>
{}));
return
threadid_to_wave_idx_adaptor
.
CalculateBottomIndex
(
make_multi_index
(
thread_id
));
}
// Default, Block buffer in LDS, thread level offset enabled
__device__
static
auto
CalculateAThreadOriginDataIndex
()
{
if
constexpr
(
AEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
WMMA_a_idx
=
wmma_gemm
.
CalculateAThreadOriginDataIndex
();
// |KRepeat |MRepeat|MWave |KRow |MLane |KPack
return
make_tuple
(
0
,
0
,
waveId_m
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_a_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
__device__
static
auto
CalculateBThreadOriginDataIndex
()
{
if
constexpr
(
BEnableLds
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
WMMA_b_idx
=
wmma_gemm
.
CalculateBThreadOriginDataIndex
();
// |KRepeat |NRepeat|Nwave |KRow |NLane |KPack
return
make_tuple
(
0
,
0
,
waveId_n
,
wmma_gemm
.
GetSubGroupId
(),
WMMA_b_idx
,
0
);
}
else
{
return
make_tuple
(
0
,
0
,
0
,
0
,
0
,
0
);
}
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk
();
constexpr
auto
mrepeat_mwave_mperWMMA_to_m_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
MRepeat
,
MWaves
,
MPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
constexpr
auto
nrepeat_nwave_nperWMMA_to_n_adaptor
=
make_single_stage_tensor_adaptor
(
make_tuple
(
make_unmerge_transform
(
make_tuple
(
NRepeat
,
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{}));
const
index_t
c_thread_m
=
mrepeat_mwave_mperWMMA_to_m_adaptor
.
CalculateBottomIndex
(
make_tuple
(
m0
,
waveId_m
,
blk_idx
[
I0
]))[
I0
];
const
index_t
c_thread_n
=
nrepeat_nwave_nperWMMA_to_n_adaptor
.
CalculateBottomIndex
(
make_tuple
(
n0
,
waveId_n
,
blk_idx
[
I1
]))[
I0
];
return
make_tuple
(
c_thread_m
,
c_thread_n
);
}
template
<
index_t
m0
,
index_t
n0
>
__device__
static
auto
CalculateCThreadOriginDataIndex7D
(
Number
<
m0
>
,
Number
<
n0
>
)
{
const
auto
wave_idx
=
GetWaveIdx
();
const
auto
waveId_m
=
wave_idx
[
I0
];
const
auto
waveId_n
=
wave_idx
[
I1
];
const
auto
blk_idx
=
wmma_gemm
.
GetBeginOfThreadBlk3D
();
return
make_tuple
(
Number
<
m0
>
{},
waveId_m
,
blk_idx
[
I0
],
Number
<
n0
>
{},
waveId_n
,
blk_idx
[
I1
],
blk_idx
[
I2
]);
}
using
Tuple6
=
decltype
(
CalculateAThreadOriginDataIndex
());
__host__
__device__
BlockwiseGemmWMMA
(
Tuple6
a_origin
=
CalculateAThreadOriginDataIndex
(),
Tuple6
b_origin
=
CalculateBThreadOriginDataIndex
())
:
a_thread_copy_
(
a_origin
),
b_thread_copy_
(
b_origin
)
{
static_assert
(
ABlockDesc
::
IsKnownAtCompileTime
()
&&
BBlockDesc
::
IsKnownAtCompileTime
(),
"wrong! Desc should be known at compile-time"
);
static_assert
(
ThisThreadBlock
::
GetNumOfThread
()
==
MWaves
*
NWaves
*
WaveSize
,
"ThisThreadBlock::GetNumOfThread() != MWaves * NWaves * WaveSize
\n
"
);
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
NAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
return
make_naive_tensor_descriptor_packed
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
NAccVgprs
));
}
// Thread level, register decriptor. Vector-write
__host__
__device__
static
constexpr
auto
GetCThreadDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
=
wmma_gemm
.
GetCMSubGroupNThreadPerSubGroupMAccVgprsThreadBlkLengths
();
constexpr
auto
MAccVgprs
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I2
];
constexpr
auto
AccStride
=
c_msubgroup_nthreadpersubgroup_maccvgprs_tblk_lens
[
I3
];
return
make_naive_tensor_descriptor
(
// |MRepeat |MWave |MSubGroup |NRepeat |NWave
// |NThreadPerSubGroup |MAccVgprs
make_tuple
(
Number
<
MRepeat
>
{},
I1
,
I1
,
Number
<
NRepeat
>
{},
I1
,
I1
,
MAccVgprs
),
make_tuple
(
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
Number
<
NRepeat
>
{}
*
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
MAccVgprs
*
AccStride
,
AccStride
));
}
template
<
typename
CGridDesc_M_N
>
__host__
__device__
static
constexpr
auto
MakeCGridDescriptor_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
const
CGridDesc_M_N
&
c_grid_desc_m_n
)
{
const
auto
M
=
c_grid_desc_m_n
.
GetLength
(
I0
);
const
auto
N
=
c_grid_desc_m_n
.
GetLength
(
I1
);
const
auto
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
=
transform_tensor_descriptor
(
c_grid_desc_m_n
,
make_tuple
(
make_unmerge_transform
(
make_tuple
(
M
/
(
MWaves
*
MPerWMMA
),
MWaves
,
MPerWMMA
)),
make_unmerge_transform
(
make_tuple
(
N
/
(
NWaves
*
NPerWMMA
),
NWaves
,
NPerWMMA
))),
make_tuple
(
Sequence
<
0
>
{},
Sequence
<
1
>
{}),
make_tuple
(
Sequence
<
0
,
1
,
2
>
{},
Sequence
<
3
,
4
,
5
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_grid_desc_mblockxrepeat_mwave_mperwmma_nblockxrepeat_nwave_nperwmma
);
}
// transposed WMMA output C' = B' * A'
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MThreadPerSubGroup_NRepeat_NWave_NSubGroup_NAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MThreadPerSubGroup_NBlockxRepeat_NWave_NSubGroup_NAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Provide dimension size
__host__
__device__
static
constexpr
auto
GetCBlockDescriptor_MRepeat_MWave_MSubGroup_NRepeat_NWave_NThreadPerSubGroup_MAccVgprs
()
{
constexpr
auto
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
MWaves
>
{},
Number
<
MPerWMMA
>
{},
Number
<
NRepeat
>
{},
Number
<
NWaves
>
{},
Number
<
NPerWMMA
>
{}));
return
wmma_gemm
.
MakeCDesc_MBlockxRepeat_MWave_MSubGroup_NBlockxRepeat_NWave_NThreadPerSubGroup_MAccVgprs
(
c_block_desc_mrepeat_mwave_mperwmma_nrepeat_nwave_nperwmma
);
}
// Describe how data allocated in thread copy src buffer
// M0_M1_M2 = MRepeat_MWave_MPerWmma, N0_N1_N2 = NRepeat_NWave_NPerWmma
static
constexpr
ABlockDesc
a_block_desc_k0_m0_m1_m2_k1
;
static
constexpr
BBlockDesc
b_block_desc_k0_n0_n1_n2_k1
;
template
<
typename
ABlockBuffer
,
typename
BBlockBuffer
,
typename
CThreadBuffer
>
__device__
void
Run
(
const
ABlockBuffer
&
a_block_buf
,
const
BBlockBuffer
&
b_block_buf
,
CThreadBuffer
&
c_thread_buf
)
const
{
auto
a_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatA
>
(
a_thread_desc_
.
GetElementSpaceSize
());
auto
b_thread_buf
=
make_static_buffer
<
AddressSpaceEnum
::
Vgpr
,
FloatB
>
(
b_thread_desc_
.
GetElementSpaceSize
());
static_assert
(
KPack
%
(
A_K1
*
A_KRow
)
==
0
,
""
);
static_assert
(
KPack
%
(
B_K1
*
B_KRow
)
==
0
,
""
);
// basic intrinsic to determine loopover direction
if
constexpr
(
MRepeat
<
NRepeat
)
{
static_for
<
0
,
KPerBlock
/
KPack
,
1
>
{}(
[
&
](
auto
k
)
{
// k=0,1,2 instead of k=0,kpack*1, ...
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
vector_type
<
FloatA
,
KPack
/
A_KRow
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
/
B_KRow
>
b_thread_vec
;
static_for
<
0
,
KPack
/
A_KRow
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
});
static_for
<
0
,
KPack
/
B_KRow
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
/
A_KRow
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
/
B_KRow
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
else
{
static_for
<
0
,
NRepeat
,
1
>
{}([
&
](
auto
n0
)
{
static_for
<
0
,
MRepeat
,
1
>
{}([
&
](
auto
m0
)
{
static_for
<
0
,
KPerBlock
/
KPack
,
1
>
{}([
&
](
auto
k
)
{
// k=0,1,2 instead of
// k=0,kpack*1, ..
// read B
b_thread_copy_
.
Run
(
b_block_desc_k0_n0_n1_n2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
B_K1
/
B_KRow
>
{},
n0
,
I0
,
I0
,
I0
,
I0
),
b_block_buf
,
b_thread_desc_
,
make_tuple
(
I0
,
n0
,
I0
,
I0
,
I0
,
I0
),
b_thread_buf
);
// read A
a_thread_copy_
.
Run
(
a_block_desc_k0_m0_m1_m2_k1
,
make_tuple
(
Number
<
k
*
KPack
/
A_K1
/
A_KRow
>
{},
m0
,
I0
,
I0
,
I0
,
I0
),
a_block_buf
,
a_thread_desc_
,
make_tuple
(
I0
,
m0
,
I0
,
I0
,
I0
,
I0
),
a_thread_buf
);
vector_type
<
FloatA
,
KPack
/
A_KRow
>
a_thread_vec
;
vector_type
<
FloatB
,
KPack
/
B_KRow
>
b_thread_vec
;
static_for
<
0
,
KPack
/
A_KRow
,
1
>
{}([
&
](
auto
i
)
{
a_thread_vec
.
template
AsType
<
FloatA
>()(
i
)
=
a_thread_buf
[
Number
<
a_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
A_K1
,
m0
,
0
,
0
,
0
,
i
%
A_K1
))
>
{}];
});
static_for
<
0
,
KPack
/
B_KRow
,
1
>
{}([
&
](
auto
i
)
{
b_thread_vec
.
template
AsType
<
FloatB
>()(
i
)
=
b_thread_buf
[
Number
<
b_thread_desc_
.
CalculateOffset
(
make_tuple
(
i
/
B_K1
,
n0
,
0
,
0
,
0
,
i
%
B_K1
))
>
{}];
});
using
wmma_input_type_a
=
typename
vector_type
<
FloatA
,
WmmaK
/
A_KRow
>::
type
;
using
wmma_input_type_b
=
typename
vector_type
<
FloatB
,
WmmaK
/
B_KRow
>::
type
;
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
});
});
});
}
}
protected:
static
constexpr
auto
a_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
A_K1
/
A_KRow
>
{},
Number
<
MRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
A_K1
>
{}),
make_tuple
(
Number
<
A_K1
>
{},
Number
<
KPack
/
A_KRow
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
A_K1
>
{},
Number
<
1
>
{}));
static
constexpr
auto
b_thread_desc_
=
make_naive_tensor_descriptor
(
make_tuple
(
Number
<
KPack
/
B_K1
/
B_KRow
>
{},
Number
<
NRepeat
>
{},
I1
,
I1
,
I1
,
Number
<
B_K1
>
{}),
make_tuple
(
Number
<
B_K1
>
{},
Number
<
KPack
/
B_KRow
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
B_K1
>
{},
Number
<
1
>
{}));
// C[M, N, NumRegWMMA]
static
constexpr
auto
c_thread_desc_
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
Number
<
MRepeat
>
{},
Number
<
NRepeat
>
{},
wmma_gemm
.
GetRegSizePerWmma
()));
template
<
bool
EnableLds
>
struct
AThreadCopySelector
;
template
<
>
struct
AThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
A_K1
>
;
};
template
<
>
struct
AThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
<
FloatA
,
FloatA
,
decltype
(
a_block_desc_k0_m0_m1_m2_k1
),
decltype
(
a_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
KPack
/
A_K1
/
A_KRow
,
1
,
1
,
1
,
1
,
A_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
A_K1
,
false
>
;
};
template
<
bool
EnableLds
>
struct
BThreadCopySelector
;
template
<
>
struct
BThreadCopySelector
<
true
>
{
using
type
=
ThreadwiseTensorSliceTransfer_v4
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
B_K1
>
;
};
template
<
>
struct
BThreadCopySelector
<
false
>
{
using
type
=
ThreadwiseTensorSliceTransfer_StaticToStatic_IntraRow
<
FloatB
,
FloatB
,
decltype
(
b_block_desc_k0_n0_n1_n2_k1
),
decltype
(
b_thread_desc_
),
tensor_operation
::
element_wise
::
PassThrough
,
Sequence
<
KPack
/
B_K1
/
B_KRow
,
1
,
1
,
1
,
1
,
B_K1
>
,
Sequence
<
0
,
1
,
2
,
3
,
4
,
5
>
,
5
,
B_K1
,
false
>
;
};
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
#else
template
<
index_t
BlockSize
,
typename
FloatA
,
typename
FloatB
,
...
...
@@ -352,10 +850,9 @@ struct BlockwiseGemmWMMA
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
wmma_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -411,10 +908,9 @@ struct BlockwiseGemmWMMA
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
wmma_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
wmma_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
wmma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
wmma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -529,5 +1025,6 @@ struct BlockwiseGemmWMMA
typename
AThreadCopySelector
<
AEnableLds
>::
type
a_thread_copy_
;
typename
BThreadCopySelector
<
BEnableLds
>::
type
b_thread_copy_
;
};
#endif
}
// namespace ck
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -340,10 +340,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
@@ -488,7 +487,14 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// sync point.
if
constexpr
(
k
.
value
!=
0
||
KPerInnerLoop
==
KPerThread
)
{
#ifdef __gfx12__
asm
volatile
(
"\
s_barrier_signal -1
\n
\
s_barrier_wait -1 \
"
::
);
#else
asm
volatile
(
"s_barrier"
::
);
#endif
__builtin_amdgcn_sched_barrier
(
0
);
}
static_for
<
0
,
KPerInnerLoop
,
KPack
>
{}([
&
](
auto
k_
)
{
...
...
@@ -530,10 +536,9 @@ struct BlockwiseGemmXdlopsInterwave_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1
// TODO: insert setprio in more precise manner since we
// could have more than >1 MFMA instructions in single call
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type_a
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type_b
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
if
constexpr
(
k_
.
value
==
0
&&
m0
.
value
==
0
&&
n0
.
value
==
0
)
{
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -963,10 +968,9 @@ struct BlockwiseGemmXdlops_v2
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_xdlops_skip_b_lds.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -281,10 +281,9 @@ struct BlockwiseGemmXdlops_k0mk1_k0nk1_m0n0m1n1m2m3m4n2_v1r1
constexpr
index_t
c_offset
=
c_thread_desc_
.
CalculateOffset
(
make_tuple
(
m0
,
n0
,
0
));
xdlops_gemm
.
template
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>{}));
xdlops_gemm
.
Run
(
a_thread_vec
.
template
AsType
<
mfma_input_type
>(),
b_thread_vec
.
template
AsType
<
mfma_input_type
>(),
c_thread_buf
.
GetVectorTypeReference
(
Number
<
c_offset
>
{}));
});
});
});
...
...
include/ck/tensor_operation/gpu/device/helper.hpp
0 → 100644
View file @
5ec6a912
#pragma once
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/multi_index_transform_helper.hpp"
#include "ck/tensor_operation/gpu/grid/block_to_ctile_map.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include <fstream>
#include <variant>
// functions to return the corresponding structs based on generated template parameters
using
layouts
=
std
::
variant
<
ck
::
tensor_layout
::
convolution
::
GNWK
,
ck
::
tensor_layout
::
convolution
::
GNHWK
,
ck
::
tensor_layout
::
convolution
::
NHWGK
,
ck
::
tensor_layout
::
convolution
::
GNDHWK
,
ck
::
tensor_layout
::
convolution
::
NDHWGK
>
;
// return the layout type: currently this is the only type supported in MIOpen
auto
layout_type
(
std
::
string
type
)
{
if
(
type
==
"ck::tensor_layout::convolution::NHWGK"
)
{
return
ck
::
tensor_layout
::
convolution
::
NHWGK
{};
}
throw
std
::
runtime_error
(
"Incorrect layout"
);
}
// return the right gemm spec based on the generated template parameters
ck
::
tensor_operation
::
device
::
GemmSpecialization
gemm_type
(
std
::
string
type
)
{
if
(
type
==
"ck::tensor_operation::device::GemmSpecialization::Default"
)
{
return
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
Default
;
}
if
(
type
==
"ck::tensor_operation::device::GemmSpecialization::MNKPadding"
)
{
return
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
;
}
throw
std
::
runtime_error
(
"Incorrect gemm spec: "
+
type
);
}
// return the type of convolution
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
conv_type
(
std
::
string
type
)
{
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Default"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
;
}
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Pad0"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
;
}
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::Filter1x1Stride1Pad0"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
;
}
if
(
type
==
"ck::tensor_operation::device::ConvolutionForwardSpecialization::OddC"
)
{
return
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
;
}
throw
std
::
runtime_error
(
"Incorrect conv spec: "
+
type
);
}
// Function to call on MatrixPadder via a wrapper struct
// NOTE: CK only uses MNKPadding for forward convolution
template
<
typename
CDesc_MRaw_NRaw
>
auto
pad
(
ck
::
index_t
mpb
,
ck
::
index_t
npb
,
ck
::
index_t
kpb
,
ck
::
tensor_operation
::
device
::
GemmSpecialization
gemm
,
CDesc_MRaw_NRaw
conv
)
{
if
(
gemm
==
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
)
{
ck
::
tensor_operation
::
device
::
MatrixPadder
<
ck
::
tensor_operation
::
device
::
GemmSpecialization
::
MNKPadding
,
ck
::
index_t
,
ck
::
index_t
,
ck
::
index_t
>
a
;
a
.
MPerTile_
=
mpb
;
a
.
NPerTile_
=
npb
;
a
.
KPerTile_
=
kpb
;
auto
tmp
=
grid_desc
(
a
,
conv
);
return
tmp
;
}
throw
std
::
runtime_error
(
"Incorrect template parameters, check gemm spec"
);
}
// Functions to call on TransformConvFwdToGemm through wrapper: different functions based on num
// dims
// FIXME: add a way to properly pass in the layout
auto
transform_conv
(
ck
::
index_t
num_dim
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
spec
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_strides
)
{
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
2
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
2
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
throw
std
::
runtime_error
(
"Incorrect conv spec"
);
}
auto
transform_conv_3d
(
ck
::
index_t
num_dim
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
spec
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_strides
)
{
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
3
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
3
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
throw
std
::
runtime_error
(
"Incorrect conv spec"
);
}
auto
transform_conv_1d
(
ck
::
index_t
num_dim
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
spec
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_strides
)
{
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Default
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
Filter1x1Stride1Pad0
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
if
(
num_dim
==
1
&&
spec
==
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
)
{
ck
::
tensor_operation
::
TransformConvFwdToGemm
<
1
,
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
::
OddC
>
conv_fwd
;
auto
res
=
ck
::
tensor_operation
::
TransformConv
();
return
res
.
transform_func
(
out_lengths
,
out_strides
,
conv_fwd
);
}
throw
std
::
runtime_error
(
"Incorrect dims or conv spec"
);
}
template
<
typename
CGridDesc_M_N
>
auto
block_2_etile
(
ck
::
index_t
m_per_block
,
ck
::
index_t
n_per_block
,
CGridDesc_M_N
matrix_padder
)
{
if
(
m_per_block
==
32
&&
n_per_block
==
64
)
{
auto
b2e
=
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
32
,
64
,
CGridDesc_M_N
>
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
32
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
32
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
64
&&
n_per_block
==
32
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
64
,
32
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
64
&&
n_per_block
==
64
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
64
,
64
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
64
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
64
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
32
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
32
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
64
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
64
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
128
&&
n_per_block
==
256
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
128
,
256
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
if
(
m_per_block
==
256
&&
n_per_block
==
128
)
{
ck
::
BlockToCTileMap_M00_N0_M01Adapt
<
256
,
128
,
CGridDesc_M_N
>
b2e
(
matrix_padder
);
return
b2e
.
CalculateGridSize
(
matrix_padder
);
}
throw
std
::
runtime_error
(
"Incorrect template parameters"
);
}
// wrapper functions by dims to get grid size - uses above 3 functions
// TODO: eventually remove the 1d/2d versions as CK will only support 3d convolutions
auto
get_launch_params_1d
(
ck
::
host
::
Solution
solution
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
4
>
out_strides
)
{
auto
num_dim
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NumDim"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NPerBlock"
);
auto
k_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"KPerBlock"
);
auto
GemmType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"GemmSpecialization"
);
auto
ConvType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"ConvSpecialization"
);
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
=
gemm_type
(
GemmType
);
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
ConvSpec
=
conv_type
(
ConvType
);
auto
conv_to_gemm_transformer
=
transform_conv_1d
(
num_dim
,
ConvSpec
,
out_lengths
,
out_strides
);
auto
matrix_padder
=
pad
(
m_per_block
,
n_per_block
,
k_per_block
,
GemmSpec
,
conv_to_gemm_transformer
);
auto
b2e
=
block_2_etile
(
m_per_block
,
n_per_block
,
matrix_padder
);
return
b2e
;
}
auto
get_launch_params
(
ck
::
host
::
Solution
solution
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
5
>
out_strides
)
{
auto
num_dim
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NumDim"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NPerBlock"
);
auto
k_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"KPerBlock"
);
auto
GemmType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"GemmSpecialization"
);
auto
ConvType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"ConvSpecialization"
);
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
=
gemm_type
(
GemmType
);
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
ConvSpec
=
conv_type
(
ConvType
);
auto
conv_to_gemm_transformer
=
transform_conv
(
num_dim
,
ConvSpec
,
out_lengths
,
out_strides
);
auto
matrix_padder
=
pad
(
m_per_block
,
n_per_block
,
k_per_block
,
GemmSpec
,
conv_to_gemm_transformer
);
auto
b2e
=
block_2_etile
(
m_per_block
,
n_per_block
,
matrix_padder
);
return
b2e
;
}
auto
get_launch_params_3d
(
ck
::
host
::
Solution
solution
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_lengths
,
ck
::
Array
<
ck
::
index_t
,
6
>
out_strides
)
{
auto
num_dim
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NumDim"
);
auto
m_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"MPerBlock"
);
auto
n_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"NPerBlock"
);
auto
k_per_block
=
solution
.
GetTemplateParameter
<
ck
::
index_t
>
(
"KPerBlock"
);
auto
GemmType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"GemmSpecialization"
);
auto
ConvType
=
solution
.
GetTemplateParameter
<
std
::
string
>
(
"ConvSpecialization"
);
ck
::
tensor_operation
::
device
::
GemmSpecialization
GemmSpec
=
gemm_type
(
GemmType
);
ck
::
tensor_operation
::
device
::
ConvolutionForwardSpecialization
ConvSpec
=
conv_type
(
ConvType
);
auto
conv_to_gemm_transformer
=
transform_conv_3d
(
num_dim
,
ConvSpec
,
out_lengths
,
out_strides
);
auto
matrix_padder
=
pad
(
m_per_block
,
n_per_block
,
k_per_block
,
GemmSpec
,
conv_to_gemm_transformer
);
auto
b2e
=
block_2_etile
(
m_per_block
,
n_per_block
,
matrix_padder
);
return
b2e
;
}
include/ck/tensor_operation/gpu/device/impl/codegen_device_grouped_conv_fwd_multiple_abd_xdl_cshuffle.hpp
0 → 100644
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <functional>
#include <iostream>
#include <iterator>
#include <numeric>
#include <sstream>
#include "ck/utility/common_header.hpp"
#include "ck/tensor_description/tensor_descriptor.hpp"
#include "ck/tensor_description/tensor_descriptor_helper.hpp"
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/tensor_operation/gpu/device/convolution_forward_specialization.hpp"
#include "ck/tensor_operation/operator_transform/transform_conv_fwd_to_gemm.hpp"
#include "ck/tensor_operation/gpu/device/device_grouped_conv_fwd_multiple_abd.hpp"
#include "ck/tensor_operation/gpu/device/gemm_specialization.hpp"
#include "ck/tensor_operation/gpu/device/matrix_padder.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_abd_xdl_cshuffle.hpp"
#include "ck/tensor_operation/gpu/device/impl/device_grouped_conv_utils.hpp"
#include "ck/host_utility/device_prop.hpp"
#include "ck/host_utility/kernel_launch.hpp"
#include "ck/host_utility/io.hpp"
namespace
ck
{
namespace
tensor_operation
{
namespace
device
{
namespace
{
/*
* \brief Wrapper function of GridwiseGemm::Run to realize BatchedGEMM.
*
* \tparam ComputePtrOffsetOfBatch Class that computes the base pointer offsets of A, B, C matrix
* given the batch. For example, ComputePtrOffsetOfStridedBatch() computes the offsets of evenly
* strided batched, but we can easily extend to other layouts. The returned offset can be either \p
* index_t or \p long_index_t. If it returns \p long_index_t, we are not subject to the 2GB
* limitations.
*
* \tparam Block2ETileMap Block2ETileMap::CalculateBottomIndex() takes in id of a workgroup and
* returns the 2D index of the tile that it computes. \see
* GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3::Run().
*
* \note Using \p ComputePtrOffsetOfBatch gives us the flexibility that 2 workgroups can compute 2
* tiles from different matrices. Keep in mind that these 2 matrices can share the same grid
* descriptor (like in BatchedGEMM), or use their own grid descriptors (in GroupedGemm). \link
* impl/device_conv3d_fwd_xdl_ndhwc_kzyxc_ndhwk.hpp kernel_gemm_xdlops_v2r3_for_conv3d \endlink for
* \link DeviceConv3d \endlink uses the same concept, but currently does NOT encapsulate the
* computing of pointer offset into \p ComputePtrOffsetOfStridedBatch.
*
* \note \p Block2ETileMap allows customized mapping between a workgroup and the C-tile it computes.
* Together with \p ComputePtrOffsetOfBatch, we can reuse GridwiseGemm (and GridwiseGemm fusion ) to
* realize BatchedGemm and GroupedGemm (and the corresponding GEMM fusion).
*
*/
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
// tuples if multi AB, pointers if no
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
bool
isMultiA
,
bool
isMultiB
>
__device__
void
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx908__) || defined(__gfx90a__) || \
defined(__gfx94__))
// offset base pointer for each work-group
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
const
index_t
g_idx
=
__builtin_amdgcn_readfirstlane
(
get_block_1d_id
()
/
num_blocks_per_batch
);
const
long_index_t
e_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetEPtrOffset
(
g_idx
)));
const
auto
&
ds_batch_offset
=
compute_ptr_offset_of_batch
.
GetDsPtrOffset
(
g_idx
);
__shared__
char
p_shared
[
GridwiseGemm
::
GetSharedMemoryNumberOfByte
()];
DsPointer
p_ds_grid_grp
;
static
constexpr
index_t
NumDTensor
=
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
::
Size
();
static_for
<
0
,
NumDTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_ds_grid_grp
(
i
)
=
p_ds_grid
[
i
]
+
ds_batch_offset
[
i
];
});
if
constexpr
(
isMultiA
||
isMultiB
)
{
AsPointer
p_as_grid_grp
;
BsPointer
p_bs_grid_grp
;
const
auto
&
as_batch_offset
=
compute_ptr_offset_of_batch
.
GetAsPtrOffset
(
g_idx
);
static
constexpr
index_t
NumATensor
=
AGridDesc_AK0_M_AK1
::
Size
();
static_for
<
0
,
NumATensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_as_grid_grp
(
i
)
=
p_as_grid
[
i
]
+
as_batch_offset
[
i
];
});
const
auto
&
bs_batch_offset
=
compute_ptr_offset_of_batch
.
GetBsPtrOffset
(
g_idx
);
static
constexpr
index_t
NumBTensor
=
BGridDesc_BK0_N_BK1
::
Size
();
static_for
<
0
,
NumBTensor
,
1
>
{}(
[
&
](
auto
i
)
{
p_bs_grid_grp
(
i
)
=
p_bs_grid
[
i
]
+
bs_batch_offset
[
i
];
});
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid_grp
,
p_bs_grid_grp
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
}
else
{
const
long_index_t
a_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetAPtrOffset
(
g_idx
)));
const
long_index_t
b_batch_offset
=
__builtin_amdgcn_readfirstlane
(
static_cast
<
long_index_t
>
(
compute_ptr_offset_of_batch
.
GetBPtrOffset
(
g_idx
)));
GridwiseGemm
::
template
Run
<
HasMainKBlockLoop
>(
p_as_grid
+
a_batch_offset
,
p_bs_grid
+
b_batch_offset
,
p_ds_grid_grp
,
p_e_grid
+
e_batch_offset
,
p_shared
,
a_element_op
,
b_element_op
,
cde_element_op
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
);
}
#else
ignore
=
p_as_grid
;
ignore
=
p_bs_grid
;
ignore
=
p_ds_grid
;
ignore
=
p_e_grid
;
ignore
=
batch_count
;
ignore
=
a_grid_desc_k0_m_k1
;
ignore
=
b_grid_desc_k0_n_k1
;
ignore
=
ds_grid_desc_mblock_mperblock_nblock_nperblock
;
ignore
=
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
ignore
=
a_element_op
;
ignore
=
b_element_op
;
ignore
=
cde_element_op
;
ignore
=
compute_ptr_offset_of_batch
;
ignore
=
block_2_ctile_map
;
#endif
}
template
<
typename
GridwiseGemm
,
typename
AsPointer
,
// tuples if multi AB, pointers if no
typename
BsPointer
,
typename
DsPointer
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
typename
AGridDesc_AK0_M_AK1
,
typename
BGridDesc_BK0_N_BK1
,
typename
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
typename
Block2ETileMap
,
typename
ComputePtrOffsetOfBatch
,
bool
HasMainKBlockLoop
,
bool
isMultiA
,
bool
isMultiB
>
__global__
void
#if CK_USE_LAUNCH_BOUNDS
__launch_bounds__
(
CK_MAX_THREAD_PER_BLOCK
,
CK_MIN_BLOCK_PER_CU
)
#endif
kernel_grouped_conv_fwd_multiple_abd_xdl_cshuffle
(
AsPointer
p_as_grid
,
BsPointer
p_bs_grid
,
DsPointer
p_ds_grid
,
EDataType
*
__restrict__
p_e_grid
,
const
AElementwiseOperation
a_element_op
,
const
BElementwiseOperation
b_element_op
,
const
CDEElementwiseOperation
cde_element_op
,
const
index_t
batch_count
,
const
AGridDesc_AK0_M_AK1
a_grid_desc_k0_m_k1
,
const
BGridDesc_BK0_N_BK1
b_grid_desc_k0_n_k1
,
const
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
const
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
const
Block2ETileMap
block_2_ctile_map
,
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
)
{
device_grouped_conv_fwd_multiple_abd_xdl_cshuffle
<
GridwiseGemm
,
AsPointer
,
// tuples if multi AB, pointers if no
BsPointer
,
DsPointer
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
AGridDesc_AK0_M_AK1
,
BGridDesc_BK0_N_BK1
,
DsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
,
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
,
Block2ETileMap
,
ComputePtrOffsetOfBatch
,
HasMainKBlockLoop
,
isMultiA
,
isMultiB
>
(
p_as_grid
,
p_bs_grid
,
p_ds_grid
,
*
p_e_grid
,
a_element_op
,
b_element_op
,
cde_element_op
,
batch_count
,
a_grid_desc_k0_m_k1
,
b_grid_desc_k0_n_k1
,
ds_grid_desc_mblock_mperblock_nblock_nperblock
,
e_grid_desc_mblock_mperblock_nblock_nperblock_
,
block_2_ctile_map
,
compute_ptr_offset_of_batch
);
}
}
// namespace
template
<
typename
T
>
using
is_tuple
=
decltype
(
std
::
declval
<
T
&>
().
IsTuple
());
//
// @brief Device Convolution operation.
//
// Supports:
// @li Forward convolution with up to 3 spatial dimentions
// @li Input tensor in GNWC data format
// @li Weight tensor in GKXC data format
// @li Output tensor in GNWK data format
//
// 1D:
// out[N, Wo, K] = in[N, Wi, C] * wei[K, X, C]
// 2D:
// out[N, Ho, Wo, K] = in[N, Hi, Wi, C] * wei[K, Y, X, C]
// 3D:
// out[N, Do, Ho, Wo, K] = in[N, Di, Hi, Wi, C] * wei[K, Z, Y, X, C]
//
template
<
index_t
NDimSpatial
,
typename
ALayout
,
typename
BLayout
,
typename
DsLayout
,
typename
ELayout
,
typename
ADataType
,
typename
BDataType
,
typename
AccDataType
,
typename
CShuffleDataType
,
typename
DsDataType
,
typename
EDataType
,
typename
AElementwiseOperation
,
typename
BElementwiseOperation
,
typename
CDEElementwiseOperation
,
ConvolutionForwardSpecialization
ConvForwardSpecialization
,
GemmSpecialization
GemmSpec
,
index_t
NumGemmKPrefetchStage
,
index_t
BlockSize
,
index_t
MPerBlock
,
index_t
NPerBlock
,
index_t
KPerBlock
,
index_t
AK1
,
index_t
BK1
,
index_t
MPerXDL
,
index_t
NPerXDL
,
index_t
MXdlPerWave
,
index_t
NXdlPerWave
,
typename
ABlockTransferThreadClusterLengths_AK0_M_AK1
,
typename
ABlockTransferThreadClusterArrangeOrder
,
typename
ABlockTransferSrcAccessOrder
,
index_t
ABlockTransferSrcVectorDim
,
index_t
ABlockTransferSrcScalarPerVector
,
index_t
ABlockTransferDstScalarPerVector_AK1
,
index_t
ABlockLdsExtraM
,
typename
BBlockTransferThreadClusterLengths_BK0_N_BK1
,
typename
BBlockTransferThreadClusterArrangeOrder
,
typename
BBlockTransferSrcAccessOrder
,
index_t
BBlockTransferSrcVectorDim
,
index_t
BBlockTransferSrcScalarPerVector
,
index_t
BBlockTransferDstScalarPerVector_BK1
,
index_t
BBlockLdsExtraN
,
index_t
CShuffleMXdlPerWavePerShuffle
,
index_t
CShuffleNXdlPerWavePerShuffle
,
typename
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock
,
index_t
CDEBlockTransferScalarPerVector_NPerBlock
,
typename
ComputeDataType
=
decltype
(
UnpackDataType
<
is_detected
<
is_tuple
,
ADataType
>
::
value
,
Number
<
0
>
,
ADataType
>
()),
// ComputeType is InputType by default (first
// in tuple for MultiAB), unpack if tuple was
// passed
LoopScheduler
LoopSched
=
make_default_loop_scheduler
()
>
struct
CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
:
public
DeviceGroupedConvFwdMultipleABD
<
NDimSpatial
,
ALayout
,
BLayout
,
DsLayout
,
ELayout
,
ADataType
,
BDataType
,
DsDataType
,
EDataType
,
AElementwiseOperation
,
BElementwiseOperation
,
CDEElementwiseOperation
,
ComputeDataType
>
{
using
DeviceOp
=
CodegenDeviceGroupedConvFwdMultipleABD_Xdl_CShuffle
;
static
constexpr
bool
isMultiA
=
is_detected
<
is_tuple
,
ADataType
>::
value
;
static
constexpr
bool
isMultiB
=
is_detected
<
is_tuple
,
BDataType
>::
value
;
static
constexpr
index_t
NumATensor
=
GetNumABTensors
<
isMultiA
,
ADataType
>
();
static
constexpr
index_t
NumBTensor
=
GetNumABTensors
<
isMultiB
,
BDataType
>
();
static
constexpr
index_t
NumDTensor
=
DsDataType
::
Size
();
static
constexpr
auto
I0
=
Number
<
0
>
{};
static
constexpr
auto
I1
=
Number
<
1
>
{};
static
constexpr
auto
I2
=
Number
<
2
>
{};
static
constexpr
auto
I3
=
Number
<
3
>
{};
static
constexpr
auto
conv_to_gemm_transformer
=
TransformConvFwdToGemm
<
NDimSpatial
,
ConvForwardSpecialization
>
{};
static
constexpr
auto
matrix_padder
=
MatrixPadder
<
GemmSpec
,
index_t
,
index_t
,
index_t
>
{
MPerBlock
,
NPerBlock
,
KPerBlock
};
template
<
typename
ALay
>
__host__
__device__
static
auto
MakeAGridDescriptor_M_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
)
{
const
auto
in_gemmmraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeADescriptor_M_K
<
ALay
>(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
return
in_gemmm_gemmk_desc
;
}
template
<
typename
BLay
>
__host__
__device__
static
auto
MakeBGridDescriptor_N_K
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
)
{
const
auto
wei_gemmnraw_gemmkraw_desc
=
conv_to_gemm_transformer
.
template
MakeBDescriptor_N_K
<
BLay
>(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
);
const
auto
wei_gemmn_gemmk_desc
=
matrix_padder
.
PadBDescriptor_N_K
(
wei_gemmnraw_gemmkraw_desc
);
return
wei_gemmn_gemmk_desc
;
}
template
<
typename
ELay
>
__host__
__device__
static
auto
MakeEGridDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
)
{
const
auto
out_gemmmraw_gemmnraw_desc
=
conv_to_gemm_transformer
.
template
MakeCDescriptor_M_N
<
ELay
>(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
);
const
auto
out_gemmm_gemmn_desc
=
matrix_padder
.
PadCDescriptor_M_N
(
out_gemmmraw_gemmnraw_desc
);
return
out_gemmm_gemmn_desc
;
}
// Shape of Ds and E must be aligned. Strides can be different.
// Pass e_g_n_k_wos_lengths for logical broadcast.
__host__
__device__
static
auto
MakeDsGridDescriptor_M_N
(
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
)
{
return
generate_tuple
(
[
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
return
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
]);
},
Number
<
NumDTensor
>
{});
}
// desc for problem definition
using
AGridDesc_M_K
=
remove_cvref_t
<
decltype
(
MakeAGridDescriptor_M_K
<
ALayout
>
(
{},
{},
{},
{},
{},
{},
{},
{},
{},
{}))
>
;
using
BGridDesc_N_K
=
remove_cvref_t
<
decltype
(
MakeBGridDescriptor_N_K
<
BLayout
>
({},
{}))
>
;
using
DsGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeDsGridDescriptor_M_N
({},
{}))
>
;
using
EGridDesc_M_N
=
remove_cvref_t
<
decltype
(
MakeEGridDescriptor_M_N
<
ELayout
>
({},
{}))
>
;
// If we are using multiAB and one of the template datatype parameters is not a tuple, convert
// it to it
using
GemmADataType
=
std
::
conditional_t
<!
isMultiA
&&
isMultiB
,
Tuple
<
ADataType
>
,
ADataType
>
;
using
GemmBDataType
=
std
::
conditional_t
<!
isMultiB
&&
isMultiA
,
Tuple
<
BDataType
>
,
BDataType
>
;
#define GridwiseGemmTemplateParameters \
GemmADataType, GemmBDataType, ComputeDataType, AccDataType, CShuffleDataType, DsDataType, \
EDataType, AElementwiseOperation, BElementwiseOperation, CDEElementwiseOperation, \
InMemoryDataOperationEnum::Set, NumGemmKPrefetchStage, BlockSize, MPerBlock, NPerBlock, \
KPerBlock, AK1, BK1, MPerXDL, NPerXDL, MXdlPerWave, NXdlPerWave, \
ABlockTransferThreadClusterLengths_AK0_M_AK1, ABlockTransferThreadClusterArrangeOrder, \
ABlockTransferSrcAccessOrder, ABlockTransferSrcVectorDim, \
ABlockTransferSrcScalarPerVector, ABlockTransferDstScalarPerVector_AK1, false, \
ABlockLdsExtraM, BBlockTransferThreadClusterLengths_BK0_N_BK1, \
BBlockTransferThreadClusterArrangeOrder, BBlockTransferSrcAccessOrder, \
BBlockTransferSrcVectorDim, BBlockTransferSrcScalarPerVector, \
BBlockTransferDstScalarPerVector_BK1, false, BBlockLdsExtraN, \
CShuffleMXdlPerWavePerShuffle, CShuffleNXdlPerWavePerShuffle, \
CDEBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, \
CDEBlockTransferScalarPerVector_NPerBlock, LoopSched
// Use appropriate gridwise gemm
using
GridwiseGemm
=
std
::
conditional_t
<
isMultiA
||
isMultiB
,
GridwiseGemmMultipleABD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>
,
GridwiseGemmMultipleD_xdl_cshuffle
<
GridwiseGemmTemplateParameters
>>
;
// If ADataTypes or BDataTypes is tuple, user has to pass ck::Array with pointers.
using
APointers
=
std
::
conditional_t
<
isMultiA
,
ck
::
Array
<
const
void
*
,
NumATensor
>&
,
const
void
*>
;
using
BPointers
=
std
::
conditional_t
<
isMultiB
,
ck
::
Array
<
const
void
*
,
NumBTensor
>&
,
const
void
*>
;
// Use Tuple for the both cases for GridPointer to initialize it in Argument constructor (not
// in initializer list what is required for single const pointer).
using
AGridPointer
=
remove_cvref_t
<
decltype
(
GetAGridPointer
<
isMultiA
||
isMultiB
,
GridwiseGemm
,
ADataType
>
())
>
;
using
BGridPointer
=
remove_cvref_t
<
decltype
(
GetBGridPointer
<
isMultiA
||
isMultiB
,
GridwiseGemm
,
BDataType
>
())
>
;
// desc for blockwise copy
using
AGridDesc_AK0_M_AK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
AGridDesc_M_K
{}))
>
;
using
BGridDesc_BK0_N_BK1
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
BGridDesc_N_K
{}))
>
;
using
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
DsGridDesc_M_N
{}))
>
;
using
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
EGridDesc_M_N
{}))
>
;
// block-to-e-tile map
using
Block2ETileMap
=
remove_cvref_t
<
decltype
(
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
EGridDesc_M_N
{}))
>
;
// Argument
struct
Argument
{
__device__
__host__
Argument
(
APointers
p_as
,
BPointers
p_bs
,
const
ck
::
Array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
:
p_as_grid_
{},
p_bs_grid_
{},
p_ds_grid_
{},
p_e_grid_
{
static_cast
<
EDataType
*>
(
p_e
)},
num_group_
{
a_g_n_c_wis_lengths
[
0
]},
a_grid_desc_m_k_
{
DeviceOp
::
MakeAGridDescriptor_M_K
<
ALayout
>
(
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
)},
b_grid_desc_n_k_
{
DeviceOp
::
MakeBGridDescriptor_N_K
<
BLayout
>
(
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
)},
ds_grid_desc_m_n_
{},
e_grid_desc_m_n_
{
DeviceOp
::
MakeEGridDescriptor_M_N
<
ELayout
>
(
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
)},
a_grid_desc_ak0_m_ak1_
{
GridwiseGemm
::
MakeDefaultAGridDescriptor_AK0_M_AK1
(
a_grid_desc_m_k_
)},
b_grid_desc_bk0_n_bk1_
{
GridwiseGemm
::
MakeDefaultBGridDescriptor_BK0_N_BK1
(
b_grid_desc_n_k_
)},
ds_grid_desc_mblock_mperblock_nblock_nperblock_
{},
e_grid_desc_mblock_mperblock_nblock_nperblock_
{},
block_2_etile_map_
{
GridwiseGemm
::
MakeDefaultBlock2ETileMap
(
e_grid_desc_m_n_
)},
compute_ptr_offset_of_batch_
{},
a_element_op_
{
a_element_op
},
b_element_op_
{
b_element_op
},
cde_element_op_
{
cde_element_op
},
a_g_n_c_wis_lengths_
{
a_g_n_c_wis_lengths
},
a_g_n_c_wis_strides_
{
a_g_n_c_wis_strides
},
b_g_k_c_xs_lengths_
{
b_g_k_c_xs_lengths
},
b_g_k_c_xs_strides_
{
b_g_k_c_xs_strides
},
ds_g_n_k_wos_lengths_
{
ds_g_n_k_wos_lengths
},
ds_g_n_k_wos_strides_
{
ds_g_n_k_wos_strides
},
e_g_n_k_wos_lengths_
{
e_g_n_k_wos_lengths
},
e_g_n_k_wos_strides_
{
e_g_n_k_wos_strides
},
conv_filter_strides_
{
conv_filter_strides
},
conv_filter_dilations_
{
conv_filter_dilations
},
input_left_pads_
{
input_left_pads
},
input_right_pads_
{
input_right_pads
}
{
// A/B/E Batch Stride
if
constexpr
(
isMultiA
||
isMultiB
)
{
static_for
<
0
,
NumATensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_
.
BatchStrideA_
(
i
)
=
a_g_n_c_wis_strides
[
0
];
// Use GemmADataType/GemmBDataType to iterate over tuple (even if passed data
// type is not tuple)
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmADataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if
constexpr
(
isMultiA
)
{
// p_as is tuple
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
[
i
.
value
]);
}
else
{
// if MultiB and not MultiA then p_as is single pointer
p_as_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_as
);
}
});
static_for
<
0
,
NumBTensor
,
1
>
{}([
&
](
auto
i
)
{
// Init compute_ptr_offset_of_batch_ for multiple AB
compute_ptr_offset_of_batch_
.
BatchStrideB_
(
i
)
=
b_g_k_c_xs_strides
[
0
];
using
DataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
GemmBDataType
>>
;
// It is possible that one of the AB is a pointer and one is a tuple.
// Then also use multiAB but we have to cast single pointer instead of tuple of
// pointer.
if
constexpr
(
isMultiB
)
{
// p_bs is tuple
p_bs_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_bs
[
i
.
value
]);
}
else
{
// if MultiA and not MultiB then p_bs is single pointer
p_bs_grid_
(
i
)
=
static_cast
<
const
DataType
*>
(
p_bs
);
}
});
}
else
{
compute_ptr_offset_of_batch_
.
BatchStrideA_
=
a_g_n_c_wis_strides
[
0
];
compute_ptr_offset_of_batch_
.
BatchStrideB_
=
b_g_k_c_xs_strides
[
0
];
// p_as and p_bs are pointers
p_as_grid_
(
I0
)
=
static_cast
<
const
ADataType
*>
(
p_as
);
p_bs_grid_
(
I0
)
=
static_cast
<
const
BDataType
*>
(
p_bs
);
}
// populate pointer, batch stride, desc for Ds
static_for
<
0
,
NumDTensor
,
1
>
{}([
&
](
auto
i
)
{
using
DLayout
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsLayout
>>
;
using
DDataType
=
remove_cvref_t
<
tuple_element_t
<
i
.
value
,
DsDataType
>>
;
// D pointer
p_ds_grid_
(
i
)
=
static_cast
<
const
DDataType
*>
(
p_ds
[
i
]);
// D batch stride
compute_ptr_offset_of_batch_
.
BatchStrideDs_
(
i
)
=
ds_g_n_k_wos_strides
[
i
][
0
];
// D desc
ds_grid_desc_m_n_
(
i
)
=
DeviceOp
::
MakeEGridDescriptor_M_N
<
DLayout
>
(
e_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
[
i
]);
});
compute_ptr_offset_of_batch_
.
BatchStrideE_
=
e_g_n_k_wos_strides
[
0
];
// populate desc for Ds/E
if
constexpr
(
isMultiA
||
isMultiB
)
{
const
auto
as_grid_desc_ak0_m_ak1
=
generate_tuple
([
&
](
auto
)
{
return
a_grid_desc_m_k_
;
},
Number
<
NumATensor
>
{});
const
auto
bs_grid_desc_bk0_n_bk1
=
generate_tuple
([
&
](
auto
)
{
return
b_grid_desc_n_k_
;
},
Number
<
NumBTensor
>
{});
if
(
GridwiseGemm
::
CheckValidity
(
as_grid_desc_ak0_m_ak1
,
bs_grid_desc_bk0_n_bk1
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
}
else
{
if
(
GridwiseGemm
::
CheckValidity
(
a_grid_desc_m_k_
,
b_grid_desc_n_k_
,
ds_grid_desc_m_n_
,
e_grid_desc_m_n_
,
block_2_etile_map_
))
{
e_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeEGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
e_grid_desc_m_n_
);
ds_grid_desc_mblock_mperblock_nblock_nperblock_
=
GridwiseGemm
::
MakeDsGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
(
ds_grid_desc_m_n_
);
}
}
}
// private:
// pointers (tuple if multi AB, pointer if no)
AGridPointer
p_as_grid_
;
BGridPointer
p_bs_grid_
;
typename
GridwiseGemm
::
DsGridPointer
p_ds_grid_
;
EDataType
*
p_e_grid_
;
// tensor descriptors for problem definiton
index_t
num_group_
;
AGridDesc_M_K
a_grid_desc_m_k_
;
BGridDesc_N_K
b_grid_desc_n_k_
;
DsGridDesc_M_N
ds_grid_desc_m_n_
;
EGridDesc_M_N
e_grid_desc_m_n_
;
// tensor descriptors for block/thread-wise copy
AGridDesc_AK0_M_AK1
a_grid_desc_ak0_m_ak1_
;
BGridDesc_BK0_N_BK1
b_grid_desc_bk0_n_bk1_
;
DsGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
ds_grid_desc_mblock_mperblock_nblock_nperblock_
;
EGridDesc_MBlock_MPerBlock_NBlock_NPerBlock
e_grid_desc_mblock_mperblock_nblock_nperblock_
;
// block-to-e-tile map
Block2ETileMap
block_2_etile_map_
;
// for computing batch offset
ComputePtrOffsetOfStridedBatch
<
NumATensor
,
NumBTensor
,
NumDTensor
>
compute_ptr_offset_of_batch_
;
// element-wise op
AElementwiseOperation
a_element_op_
;
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// for checking IsSupportedArgument()
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_lengths_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
a_g_n_c_wis_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_lengths_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
b_g_k_c_xs_strides_
;
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_lengths_
;
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>
ds_g_n_k_wos_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_lengths_
;
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
e_g_n_k_wos_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
conv_filter_strides_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
conv_filter_dilations_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
input_left_pads_
;
ck
::
Array
<
index_t
,
NDimSpatial
>
input_right_pads_
;
};
static
__device__
__host__
auto
MakeArgument
(
APointers
p_as
,
BPointers
p_bs
,
const
ck
::
Array
<
const
void
*
,
NumDTensor
>&
p_ds
,
void
*
p_e
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
a_g_n_c_wis_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
b_g_k_c_xs_strides
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_lengths
,
const
ck
::
Array
<
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>
,
NumDTensor
>&
ds_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_lengths
,
const
ck
::
Array
<
index_t
,
NDimSpatial
+
3
>&
e_g_n_k_wos_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_strides
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
conv_filter_dilations
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_left_pads
,
const
ck
::
Array
<
index_t
,
NDimSpatial
>&
input_right_pads
,
const
AElementwiseOperation
&
a_element_op
,
const
BElementwiseOperation
&
b_element_op
,
const
CDEElementwiseOperation
&
cde_element_op
)
{
return
Argument
{
p_as
,
p_bs
,
p_ds
,
p_e
,
a_g_n_c_wis_lengths
,
a_g_n_c_wis_strides
,
b_g_k_c_xs_lengths
,
b_g_k_c_xs_strides
,
ds_g_n_k_wos_lengths
,
ds_g_n_k_wos_strides
,
e_g_n_k_wos_lengths
,
e_g_n_k_wos_strides
,
conv_filter_strides
,
conv_filter_dilations
,
input_left_pads
,
input_right_pads
,
a_element_op
,
b_element_op
,
cde_element_op
};
}
};
}
// namespace device
}
// namespace tensor_operation
}
// namespace ck
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
5ec6a912
...
...
@@ -133,8 +133,13 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
constexpr
auto
NWaves
=
NPerBlock
/
(
NRepeat
*
NPerWmma
);
static
constexpr
auto
WmmaK
=
K1
==
16
?
32
:
16
;
static
constexpr
auto
AEnableLds_auto
=
NWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
MaxVectorLoadA
=
K1
*
sizeof
(
ADataType
)
==
16
?
true
:
false
;
static
constexpr
auto
MaxVectorLoadB
=
K1
*
sizeof
(
BDataType
)
==
16
?
true
:
false
;
static
constexpr
auto
AEnableLds_auto
=
(
NWaves
==
1
&&
(
MaxVectorLoadA
||
MRepeat
==
1
))
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
(
MWaves
==
1
&&
(
MaxVectorLoadB
||
NRepeat
==
1
))
?
false
:
true
;
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
false
;
...
...
@@ -829,7 +834,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
AccDataType
,
float
>
||
is_same_v
<
AccDataType
,
int32_t
>
))
{
...
...
@@ -869,11 +874,15 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
}
else
{
if
(
!
(
arg
.
a_kz_stride_
==
1
&&
arg
.
a_grid_desc_
.
GetLength
(
I2
)
%
ABlockTransferSrcScalarPerVector
==
0
))
if
(
!
(
arg
.
a_kz_stride_
==
1
))
{
printf
(
"DeviceOp: Vector Access A-k check failure
\n
"
);
return
false
;
index_t
LastK
=
AEnableLds
?
arg
.
a_grid_desc_
.
GetLength
(
I2
)
:
arg
.
a_grid_desc_
.
GetLength
(
I6
);
if
(
LastK
%
ABlockTransferSrcScalarPerVector
==
0
)
{
printf
(
"DeviceOp: Vector Access A-k check failure
\n
"
);
return
false
;
}
}
}
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_multiple_d_dl.hpp
View file @
5ec6a912
...
...
@@ -70,8 +70,9 @@ __global__ void
const
ComputePtrOffsetOfBatch
compute_ptr_offset_of_batch
,
const
Block2CTileMap
block_2_ctile_map
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx906__) || defined(__gfx908__) || \
defined(__gfx90a__) || defined(__gfx94__) || defined(__gfx103__) || defined(__gfx11__) || \
defined(__gfx12__))
const
index_t
num_blocks_per_batch
=
__builtin_amdgcn_readfirstlane
(
get_grid_size
()
/
batch_count
);
...
...
@@ -648,7 +649,7 @@ struct DeviceBatchedGemmMultipleD_Dl : public DeviceBatchedGemmMultiD<ALayout,
static
bool
IsSupportedArgument
(
const
Argument
&
arg
)
{
if
(
ck
::
get_device_name
()
==
"gfx906"
||
ck
::
is_xdl_supported
()
||
ck
::
is_gfx103_supported
()
||
ck
::
is_gfx11_supported
())
ck
::
is_gfx103_supported
()
||
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
bool
pass
=
true
;
pass
=
pass
&&
arg
.
K_
%
K1
==
0
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
5ec6a912
...
...
@@ -56,7 +56,7 @@ __global__ void
bool
input_permute
,
bool
output_permute
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -159,6 +159,7 @@ __global__ void
ignore
=
O
;
ignore
=
G0
;
ignore
=
G1
;
ignore
=
alpha
;
ignore
=
input_permute
;
ignore
=
output_permute
;
#endif // end of if (defined(__gfx11__))
...
...
@@ -187,7 +188,7 @@ __global__ void
index_t
head_size
,
float
alpha
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -321,7 +322,7 @@ __global__ void
index_t
head_size
,
float
alpha
)
{
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__))
#if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx11__)
|| defined(__gfx12__)
)
// clang-format off
// ***************************************************
...
...
@@ -858,7 +859,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
bool
IsSupportedArgument
(
const
RawArg
&
arg
)
{
if
(
ck
::
is_gfx11_supported
())
if
(
ck
::
is_gfx11_supported
()
||
ck
::
is_gfx12_supported
()
)
{
if
constexpr
(
!
(
is_same_v
<
Acc0DataType
,
float
>
||
is_same_v
<
Acc0DataType
,
int32_t
>
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_column_to_image_impl.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -247,7 +247,8 @@ struct DeviceColumnToImageImpl
independent_filter_strides
,
conv_filter_dilations
,
input_left_pads_with_offset
,
input_right_pads
);
input_right_pads
,
N
);
const
auto
in_gemmm_gemmk_desc
=
matrix_padder
.
PadADescriptor_M_K
(
in_gemmmraw_gemmkraw_desc
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_abd_xdl_cshuffle.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -501,29 +501,24 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
// for sanity check of vector memory access
for
(
index_t
i
=
0
;
i
<
NumATensor
;
++
i
)
{
as_mz_consecutive_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
-
1
]
==
1
;
as_kz_consecutive_
[
i
]
=
a_ms_ks_strides
[
i
][
NumDimM
+
NumDimK
-
1
]
==
1
;
as_max_read_elems_
[
i
]
=
tie
(
as_continous_dim_
[
i
],
as_max_read_elems_
[
i
])
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
[
i
],
a_ms_ks_strides
[
i
]);
}
for
(
index_t
i
=
0
;
i
<
NumBTensor
;
++
i
)
{
bs_nz_consecutive_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
-
1
]
==
1
;
bs_kz_consecutive_
[
i
]
=
b_ns_ks_strides
[
i
][
NumDimN
+
NumDimK
-
1
]
==
1
;
bs_max_read_elems_
[
i
]
=
tie
(
bs_continous_dim_
[
i
],
bs_max_read_elems_
[
i
])
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
[
i
],
b_ns_ks_strides
[
i
]);
}
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_nz_consecutive_
[
i
]
=
d_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
ds_max_read_elems_
[
i
]
=
tie
(
ds_continous_dim_
[
i
],
ds_max_read_elems_
[
i
])
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
d_ms_ns_lengths
[
i
],
d_ms_ns_strides
[
i
]);
}
e_nz_consecutive_
=
e_ms_ns_stride
[
NumDimM
+
NumDimN
-
1
]
==
1
;
e_max_write_elems_
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_length
,
e_ms_ns_stride
);
tie
(
e_continous_dim_
,
e_max_write_elems_
)
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_length
,
e_ms_ns_stride
);
}
// pointers
...
...
@@ -553,14 +548,11 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
std
::
array
<
bool
,
NumATensor
>
as_mz_consecutive_
;
std
::
array
<
bool
,
NumATensor
>
as_kz_consecutive_
;
std
::
array
<
bool
,
NumBTensor
>
bs_nz_consecutive_
;
std
::
array
<
bool
,
NumBTensor
>
bs_kz_consecutive_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
bool
e_nz_consecutive_
;
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
std
::
array
<
index_t
,
NumATensor
>
as_continous_dim_
;
std
::
array
<
index_t
,
NumATensor
>
bs_continous_dim_
;
std
::
array
<
index_t
,
NumBTensor
>
ds_continous_dim_
;
index_t
e_continous_dim_
;
std
::
array
<
index_t
,
NumATensor
>
as_max_read_elems_
;
std
::
array
<
index_t
,
NumBTensor
>
bs_max_read_elems_
;
...
...
@@ -659,9 +651,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_a_vector_size
=
arg
.
as_max_read_elems_
[
i
]
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
as_
mz_consecutive_
[
i
]
;
ABlockTransferSrcVectorDim
==
1
&&
arg
.
as_
continous_dim_
[
i
]
==
0
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
as_
kz_consecutive_
[
i
]
;
ABlockTransferSrcVectorDim
==
2
&&
arg
.
as_
continous_dim_
[
i
]
==
1
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
;
if
(
!
((
valid_a_vector_size
&&
valid_a_access_dim
)
||
ABlockTransferSrcScalarPerVector
==
1
))
...
...
@@ -679,9 +671,9 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_b_vector_size
=
arg
.
bs_max_read_elems_
[
i
]
%
BBlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
bs_
nz_consecutive_
[
i
]
;
BBlockTransferSrcVectorDim
==
1
&&
arg
.
bs_
continous_dim_
[
i
]
==
0
;
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
bs_
kz_consecutive_
[
i
]
;
BBlockTransferSrcVectorDim
==
2
&&
arg
.
bs_
continous_dim_
[
i
]
==
1
;
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
;
if
(
!
((
valid_b_vector_size
&&
valid_b_access_dim
)
||
BBlockTransferSrcScalarPerVector
==
1
))
...
...
@@ -699,7 +691,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_d_vector_size
=
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
arg
.
ds_
nz_consecutive_
[
i
]
;
const
bool
valid_d_access_dim
=
arg
.
ds_
continous_dim_
[
i
]
==
1
;
if
(
!
((
valid_d_vector_size
&&
valid_d_access_dim
)
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
{
...
...
@@ -714,7 +706,7 @@ struct DeviceContractionMultipleABD_Xdl_CShuffle
const
bool
valid_e_vector_size
=
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
arg
.
e_
nz_consecutive_
;
const
bool
valid_e_access_dim
=
arg
.
e_
continous_dim_
==
1
;
if
(
!
((
valid_e_vector_size
&&
valid_e_access_dim
)
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
))
{
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_multiple_d_xdl_cshuffle.hpp
View file @
5ec6a912
...
...
@@ -442,25 +442,19 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
}
// for sanity check of vector memory access
a_mz_consecutive_
=
a_ms_ks_strides
[
NumDimM
-
1
]
==
1
;
a_kz_consecutive_
=
a_ms_ks_strides
[
NumDimM
+
NumDimK
-
1
]
==
1
;
a_max_read_elems_
=
tie
(
a_continous_dim_
,
a_max_read_elems_
)
=
CalculateMaxRead
<
NumDimM
,
NumDimK
>
(
a_ms_ks_lengths
,
a_ms_ks_strides
);
b_nz_consecutive_
=
b_ns_ks_strides
[
NumDimN
-
1
]
==
1
;
b_kz_consecutive_
=
b_ns_ks_strides
[
NumDimN
+
NumDimK
-
1
]
==
1
;
b_max_read_elems_
=
tie
(
b_continous_dim_
,
b_max_read_elems_
)
=
CalculateMaxRead
<
NumDimN
,
NumDimK
>
(
b_ns_ks_lengths
,
b_ns_ks_strides
);
for
(
index_t
i
=
0
;
i
<
NumDTensor
;
++
i
)
{
ds_nz_consecutive_
[
i
]
=
ds_ms_ns_strides
[
i
][
NumDimM
+
NumDimN
-
1
]
==
1
;
ds_max_read_elems_
[
i
]
=
tie
(
ds_continous_dim_
[
i
],
ds_max_read_elems_
[
i
])
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
ds_ms_ns_lengths
[
i
],
ds_ms_ns_strides
[
i
]);
}
e_nz_consecutive_
=
e_ms_ns_strides
[
NumDimM
+
NumDimN
-
1
]
==
1
;
e_max_write_elems_
=
tie
(
e_continous_dim_
,
e_max_write_elems_
)
=
CalculateMaxRead
<
NumDimM
,
NumDimN
>
(
e_ms_ns_lengths
,
e_ms_ns_strides
);
}
...
...
@@ -501,14 +495,11 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
BElementwiseOperation
b_element_op_
;
CDEElementwiseOperation
cde_element_op_
;
// Describe whether the last part of a given dimension of A/B/D/E is consecutive
// in the memory or not.
bool
a_mz_consecutive_
;
bool
a_kz_consecutive_
;
bool
b_nz_consecutive_
;
bool
b_kz_consecutive_
;
std
::
array
<
bool
,
NumDTensor
>
ds_nz_consecutive_
;
bool
e_nz_consecutive_
;
// Describe whether the last part of a given dimension of A/B/D/E is continues dim.
index_t
a_continous_dim_
;
index_t
b_continous_dim_
;
std
::
array
<
index_t
,
NumDTensor
>
ds_continous_dim_
;
index_t
e_continous_dim_
;
index_t
a_max_read_elems_
;
index_t
b_max_read_elems_
;
...
...
@@ -601,9 +592,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
return
false
;
}
if
(
ck
::
get_device_name
()
!=
"gfx90a"
&&
ck
::
get_device_name
()
!=
"gfx940"
&&
ck
::
get_device_name
()
!=
"gfx941"
&&
ck
::
get_device_name
()
!=
"gfx942"
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
if
(
!
ck
::
is_lds_direct_load_supported
()
&&
std
::
is_same
<
ADataType
,
double
>::
value
)
{
return
false
;
}
...
...
@@ -624,8 +613,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
bool
valid_a_vector_size
=
arg
.
a_max_read_elems_
%
ABlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_mz_consecutive_
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_kz_consecutive_
;
const
bool
valid_a_access_dim_m
=
ABlockTransferSrcVectorDim
==
1
&&
arg
.
a_continous_dim_
==
0
;
const
bool
valid_a_access_dim_k
=
ABlockTransferSrcVectorDim
==
2
&&
arg
.
a_continous_dim_
==
1
;
const
bool
valid_a_access_dim
=
valid_a_access_dim_m
||
valid_a_access_dim_k
||
ABlockTransferSrcScalarPerVector
==
1
;
if
(
!
(
valid_a_vector_size
&&
valid_a_access_dim
))
...
...
@@ -635,8 +626,10 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
const
bool
valid_b_vector_size
=
arg
.
b_max_read_elems_
%
BBlockTransferSrcScalarPerVector
==
0
;
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_nz_consecutive_
;
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_kz_consecutive_
;
const
bool
valid_b_access_dim_n
=
BBlockTransferSrcVectorDim
==
1
&&
arg
.
b_continous_dim_
==
0
;
const
bool
valid_b_access_dim_k
=
BBlockTransferSrcVectorDim
==
2
&&
arg
.
b_continous_dim_
==
1
;
const
bool
valid_b_access_dim
=
valid_b_access_dim_n
||
valid_b_access_dim_k
||
BBlockTransferSrcScalarPerVector
==
1
;
if
(
!
(
valid_b_vector_size
&&
valid_b_access_dim
))
...
...
@@ -650,7 +643,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
ds_max_read_elems_
[
i
]
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector read of Ds is always on N dimension.
const
bool
valid_d_access_dim
=
arg
.
ds_
nz_consecutive_
[
i
]
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
arg
.
ds_
continous_dim_
[
i
]
==
1
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
if
(
!
(
valid_d_vector_size
&&
valid_d_access_dim
))
{
valid_ds_access
=
false
;
...
...
@@ -665,7 +658,7 @@ struct DeviceContractionMultipleD_Xdl_CShuffle
arg
.
e_max_write_elems_
%
CDEBlockTransferScalarPerVector_NPerBlock
==
0
;
// Vector write of E is always on N dimension.
const
bool
valid_e_access_dim
=
arg
.
e_
nz_consecutive_
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
arg
.
e_
continous_dim_
==
1
||
CDEBlockTransferScalarPerVector_NPerBlock
==
1
;
if
(
!
(
valid_e_vector_size
&&
valid_e_access_dim
))
{
return
false
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_contraction_utils.hpp
View file @
5ec6a912
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023
-2024
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
...
...
@@ -50,25 +50,53 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
// Determine the beginning and end idx of the group representing the FCD.
index_t
begin_idx
,
end_idx
;
if
(
strides
[
NumDim1
-
1
]
==
1
)
index_t
begin_idx
,
end_idx
,
continous_dim
,
consecutive_stride
=
1
;
if
(
strides
[
NumDim1
-
1
]
==
1
&&
strides
[
NumDim1
+
NumDim2
-
1
]
==
1
)
{
begin_idx
=
0
;
end_idx
=
NumDim1
-
1
;
// MZ or KZ are ones
bool
dims1_are_ones
=
true
;
for
(
index_t
dim_idx
=
0
;
dim_idx
<
NumDim1
;
dim_idx
++
)
{
if
(
lengths
[
dim_idx
]
!=
1
)
{
dims1_are_ones
=
false
;
}
}
if
(
dims1_are_ones
)
{
begin_idx
=
NumDim1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
continous_dim
=
1
;
}
else
{
begin_idx
=
0
;
end_idx
=
NumDim1
-
1
;
continous_dim
=
0
;
}
}
else
if
(
strides
[
NumDim1
-
1
]
==
1
)
{
begin_idx
=
0
;
end_idx
=
NumDim1
-
1
;
continous_dim
=
0
;
}
else
if
(
strides
[
NumDim1
+
NumDim2
-
1
]
==
1
)
{
begin_idx
=
NumDim1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
begin_idx
=
NumDim1
;
end_idx
=
NumDim1
+
NumDim2
-
1
;
continous_dim
=
1
;
}
else
{
// The dimension consecutive in memory is not the last dimension of any group, so only
// one element can be read/written at once.
return
1
;
consecutive_stride
=
1
;
continous_dim
=
0
;
return
make_tuple
(
continous_dim
,
consecutive_stride
);
}
index_t
consecutive_stride
=
1
;
for
(
index_t
dim_idx
=
end_idx
;
dim_idx
>=
begin_idx
;
--
dim_idx
)
{
if
(
strides
[
dim_idx
]
==
consecutive_stride
)
...
...
@@ -81,7 +109,7 @@ auto CalculateMaxRead(const std::vector<index_t>& lengths, const std::vector<ind
}
}
const
index_t
max_subsequent_elems
=
consecutive_stride
;
return
max_subsequent_elems
;
return
make_tuple
(
continous_dim
,
max_subsequent_elems
)
;
}
}
// namespace device
...
...
Prev
1
2
3
4
5
6
7
8
9
…
12
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment