Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
20eb3b68
Commit
20eb3b68
authored
Dec 16, 2024
by
Jing Zhang
Browse files
fixed
parent
2af8f32a
Changes
5
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
76 additions
and
73 deletions
+76
-73
CMakeLists.txt
CMakeLists.txt
+1
-1
include/ck/tensor/static_tensor.hpp
include/ck/tensor/static_tensor.hpp
+10
-2
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
...nsor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
+2
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
.../block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
+62
-62
profiler/include/profiler/profile_gemm_universal_impl.hpp
profiler/include/profiler/profile_gemm_universal_impl.hpp
+1
-7
No files found.
CMakeLists.txt
View file @
20eb3b68
...
@@ -581,7 +581,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
...
@@ -581,7 +581,7 @@ if(NOT GPU_ARCHS AND USER_GPU_TARGETS)
)
)
add_subdirectory
(
example
)
add_subdirectory
(
example
)
if
(
BUILD_TESTING
)
if
(
BUILD_TESTING
)
add_subdirectory
(
test
)
add_subdirectory
(
test
)
endif
()
endif
()
endif
()
endif
()
...
...
include/ck/tensor/static_tensor.hpp
View file @
20eb3b68
...
@@ -165,7 +165,11 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -165,7 +165,11 @@ struct StaticTensorTupleOfVectorBuffer
// Get X
// Get X
// Idx is for S, not X. Idx should be aligned with X
// Idx is for S, not X. Idx should be aligned with X
template
<
typename
X
,
typename
Idx
>
template
<
typename
X
,
typename
Idx
,
typename
enable_if
<
(
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
())
&&
is_known_at_compile_time
<
Idx
>::
value
&&
Idx
::
Size
()
==
ndim_
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
X
GetAsType
(
Idx
)
const
__host__
__device__
constexpr
X
GetAsType
(
Idx
)
const
{
{
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
...
@@ -195,7 +199,11 @@ struct StaticTensorTupleOfVectorBuffer
...
@@ -195,7 +199,11 @@ struct StaticTensorTupleOfVectorBuffer
// Set X
// Set X
// Idx is for S, not X. Idx should be aligned with X
// Idx is for S, not X. Idx should be aligned with X
template
<
typename
X
,
typename
Idx
>
template
<
typename
X
,
typename
Idx
,
typename
enable_if
<
(
has_same_scalar_type
<
S
,
X
>
::
value
||
!
is_native_type
<
S
>
())
&&
is_known_at_compile_time
<
Idx
>::
value
&&
Idx
::
Size
()
==
ndim_
,
bool
>::
type
=
false
>
__host__
__device__
constexpr
void
SetAsType
(
Idx
,
X
x
)
__host__
__device__
constexpr
void
SetAsType
(
Idx
,
X
x
)
{
{
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
constexpr
auto
coord
=
make_tensor_coordinate
(
desc_
,
to_multi_index
(
Idx
{}));
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp
View file @
20eb3b68
...
@@ -407,7 +407,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
...
@@ -407,7 +407,8 @@ struct GridwiseGemm_xdl_cshuffle_v3
}
}
else
else
{
{
// Weight Tile Permute
// Pre-shuffled Weight
// BGlobal[K / KPerBlock, N, KPerBlock / K1, K1] -> BTile[K / K1, N, K1]
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
constexpr
index_t
BK01
=
KPerBlock
/
BK1Value
;
// const index_t BK00 = BK0 / BK01;
// const index_t BK00 = BK0 / BK01;
const
index_t
BK0_
=
StrideB
/
BK1Value
;
const
index_t
BK0_
=
StrideB
/
BK1Value
;
...
...
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
View file @
20eb3b68
This diff is collapsed.
Click to expand it.
profiler/include/profiler/profile_gemm_universal_impl.hpp
View file @
20eb3b68
...
@@ -230,13 +230,7 @@ bool profile_gemm_universal_impl(int do_verification,
...
@@ -230,13 +230,7 @@ bool profile_gemm_universal_impl(int do_verification,
}
}
else
else
{
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
b_k_n_permute
(
i
*
K
+
j
)
=
b_k_n
(
i
*
K
+
j
);
{
for
(
int
j
=
0
;
j
<
K
;
j
++
)
{
b_k_n_permute
(
i
*
K
+
j
)
=
b_k_n
(
i
*
K
+
j
);
}
}
}
}
b_device_buf
.
ToDevice
(
b_k_n_permute
.
mData
.
data
());
b_device_buf
.
ToDevice
(
b_k_n_permute
.
mData
.
data
());
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment