Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
7e3a5613
Commit
7e3a5613
authored
Apr 07, 2024
by
Jing Zhang
Browse files
clean up
parent
50530c17
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
15 additions
and
18 deletions
+15
-18
example/30_grouped_conv_fwd_multiple_d/common.hpp
example/30_grouped_conv_fwd_multiple_d/common.hpp
+4
-4
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
+4
-4
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
...gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
+3
-3
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
...ude/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
+0
-3
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
...l/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
+2
-2
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
...evice_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
+1
-1
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
...vice/impl/device_grouped_query_attention_forward_wmma.hpp
+1
-1
No files found.
example/30_grouped_conv_fwd_multiple_d/common.hpp
View file @
7e3a5613
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool
time_kernel
=
true
;
bool
time_kernel
=
true
;
};
};
#define DefaultConvParam \
#define DefaultConvParam
\
ck::utils::conv::ConvParam \
ck::utils::conv::ConvParam
\
{ \
{
\
2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
}
inline
void
print_help_msg
()
inline
void
print_help_msg
()
...
...
example/30_grouped_conv_fwd_multiple_d/common_wmma.hpp
View file @
7e3a5613
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
...
@@ -90,10 +90,10 @@ struct ExecutionConfig final
bool
time_kernel
=
true
;
bool
time_kernel
=
true
;
};
};
#define DefaultConvParam \
#define DefaultConvParam
\
ck::utils::conv::ConvParam \
ck::utils::conv::ConvParam
\
{ \
{
\
2, 32, 2, 32, 32, {3, 3}, {14, 14}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
2, 32, 2, 256, 192, {3, 3}, {71, 71}, {2, 2}, {1, 1}, {1, 1}, { 1, 1 } \
}
}
inline
void
print_help_msg
()
inline
void
print_help_msg
()
...
...
example/32_batched_gemm_scale_softmax_gemm/run_batched_gemm_scale_softmax_gemm_permute_wmma.inc
View file @
7e3a5613
...
@@ -9,10 +9,10 @@ int run(int argc, char* argv[])
...
@@ -9,10 +9,10 @@ int run(int argc, char* argv[])
// GEMM shape for A/B0/B1/C
// GEMM shape for A/B0/B1/C
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
// C_g_m_o = A_g_m_k * B0_g_k_n * B1_g_n_o
ck
::
index_t
M
=
60
;
ck
::
index_t
M
=
120
;
ck
::
index_t
N
=
100
;
ck
::
index_t
N
=
1000
;
ck
::
index_t
K
=
64
;
ck
::
index_t
K
=
64
;
ck
::
index_t
O
=
64
;
ck
::
index_t
O
=
128
;
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// Output shape C[G0, M, G1, O]. Batch dim, outer dim, inner dim must match GEMM shape
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
// C_g0_g1_m_o = reshape(C_g_m_o, [g0, g1, m, o])
...
...
include/ck/tensor_operation/gpu/block/blockwise_gemm_wmma.hpp
View file @
7e3a5613
...
@@ -194,9 +194,6 @@ struct BlockwiseGemmWMMA
...
@@ -194,9 +194,6 @@ struct BlockwiseGemmWMMA
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
static_assert
(
MPerBlock
%
(
MPerWMMA
*
MRepeat
)
==
0
&&
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
NPerBlock
%
(
NPerWMMA
*
NRepeat
)
==
0
,
"wrong!"
);
"wrong!"
);
// static_assert(AEnableLds == true, "only support EnableLds");
// static_assert(BEnableLds == true, "only support EnableLds");
}
}
// transposed WMMA output C' = B' * A'
// transposed WMMA output C' = B' * A'
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp
View file @
7e3a5613
...
@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
...
@@ -137,8 +137,8 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
BEnableLds_auto
=
MWaves
==
1
?
false
:
true
;
// If true, LDS is used unconditionally
// If true, LDS is used unconditionally
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
BEnableLds_manu
=
true
;
static
constexpr
auto
BEnableLds_manu
=
false
;
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
AEnableLds
=
AEnableLds_auto
||
AEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
static
constexpr
auto
BEnableLds
=
BEnableLds_auto
||
BEnableLds_manu
||
(
NumPrefetch
>
1
);
...
...
include/ck/tensor_operation/gpu/device/impl/device_batched_gemm_softmax_gemm_permute_wmma_cshuffle.hpp
View file @
7e3a5613
...
@@ -562,7 +562,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
...
@@ -562,7 +562,7 @@ struct DeviceBatchedGemmSoftmaxGemmPermute_Wmma_CShuffle
static
constexpr
auto
B0EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B0EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B1EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B1EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
...
...
include/ck/tensor_operation/gpu/device/impl/device_grouped_query_attention_forward_wmma.hpp
View file @
7e3a5613
...
@@ -300,7 +300,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
...
@@ -300,7 +300,7 @@ struct DeviceGroupedQueryAttentionForward_Wmma
static
constexpr
auto
B0EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B0EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B1EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
B1EnableLds_auto
=
MWaves
==
1
?
false
:
true
;
static
constexpr
auto
AEnableLds_manu
=
true
;
static
constexpr
auto
AEnableLds_manu
=
false
;
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B0EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
static
constexpr
auto
B1EnableLds_manu
=
true
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment