Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
69b6d2ab
Commit
69b6d2ab
authored
Jan 21, 2025
by
Adam Osewski
Browse files
Enable reading on contiguous dimension in all layouts.
parent
bd5008af
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
72 additions
and
38 deletions
+72
-38
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+46
-22
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
...ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
+26
-16
No files found.
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
69b6d2ab
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -8,7 +8,6 @@
...
@@ -8,7 +8,6 @@
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -141,7 +140,9 @@ struct GemmKernel
...
@@ -141,7 +140,9 @@ struct GemmKernel
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
GemmCommonKargs
kargs
)
const
{
{
const
auto
[
i_m
,
i_n
]
=
TilePartitioner
{}();
const
auto
idxs
=
TilePartitioner
{}();
const
auto
i_m
=
idxs
.
at
(
number
<
0
>
{});
const
auto
i_n
=
idxs
.
at
(
number
<
1
>
{});
// options
// options
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
ADataType
*
a_start
=
static_cast
<
const
ADataType
*>
(
kargs
.
a_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
const
BDataType
*
b_start
=
static_cast
<
const
BDataType
*>
(
kargs
.
b_ptr
);
...
@@ -160,9 +161,9 @@ struct GemmKernel
...
@@ -160,9 +161,9 @@ struct GemmKernel
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
a_start
,
a_start
,
make_tuple
(
kargs
.
M
,
kargs
.
K
),
make_tuple
(
kargs
.
K
,
kargs
.
M
),
make_tuple
(
1
,
kargs
.
stride_A
),
make_tuple
(
kargs
.
stride_A
,
1
),
number
<
1
>
{},
number
<
GemmPipeline
::
VectorSizeA
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
}();
}();
...
@@ -172,9 +173,9 @@ struct GemmKernel
...
@@ -172,9 +173,9 @@ struct GemmKernel
{
{
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
return
make_naive_tensor_view
<
address_space_enum
::
global
>
(
b_start
,
b_start
,
make_tuple
(
kargs
.
N
,
kargs
.
K
),
make_tuple
(
kargs
.
K
,
kargs
.
N
),
make_tuple
(
1
,
kargs
.
stride_B
),
make_tuple
(
kargs
.
stride_B
,
1
),
number
<
1
>
{},
number
<
GemmPipeline
::
VectorSizeB
>
{},
number
<
1
>
{});
number
<
1
>
{});
}
}
else
else
...
@@ -200,16 +201,27 @@ struct GemmKernel
...
@@ -200,16 +201,27 @@ struct GemmKernel
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
a_tensor_view
,
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
k
M
>
{},
number
<
TilePartitioner
::
k
K
>
{}),
make_tuple
(
number
<
TilePartitioner
::
k
K
>
{},
number
<
TilePartitioner
::
k
M
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
sequence
<
false
,
GemmPipeline
::
kPadM
>
{});
}
}
}();
}();
// clang-format on
auto
a_block_window
=
make_tile_window
(
auto
a_block_window
=
[
&
]()
{
a_pad_view
,
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
{
i_m
,
0
});
return
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
}
else
{
return
make_tile_window
(
a_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kK
>
{},
number
<
TilePartitioner
::
kM
>
{}),
{
0
,
i_m
});
}
}();
auto
b_pad_view
=
[
&
]()
{
auto
b_pad_view
=
[
&
]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
...
@@ -223,15 +235,27 @@ struct GemmKernel
...
@@ -223,15 +235,27 @@ struct GemmKernel
{
{
return
pad_tensor_view
(
return
pad_tensor_view
(
b_tensor_view
,
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
k
N
>
{},
number
<
TilePartitioner
::
k
K
>
{}),
make_tuple
(
number
<
TilePartitioner
::
k
K
>
{},
number
<
TilePartitioner
::
k
N
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
}
}();
}();
auto
b_block_window
=
make_tile_window
(
auto
b_block_window
=
[
&
]()
{
b_pad_view
,
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
{
i_n
,
0
});
return
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_n
,
0
});
}
else
{
return
make_tile_window
(
b_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kK
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
0
,
i_n
});
}
}();
// allocate LDS
// allocate LDS
__shared__
char
smem_ptr
[
GetSmemSize
()];
__shared__
char
smem_ptr
[
GetSmemSize
()];
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_base.hpp
View file @
69b6d2ab
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -12,18 +13,21 @@ struct GemmPipelineAgBgCrImplBase
...
@@ -12,18 +13,21 @@ struct GemmPipelineAgBgCrImplBase
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
using
BlockGemmShape
=
remove_cvref_t
<
typename
Problem
::
BlockGemmShape
>
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
MPerBlock
=
BlockGemmShape
::
kM
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
NPerBlock
=
BlockGemmShape
::
kN
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
static
constexpr
index_t
KPerBlock
=
BlockGemmShape
::
kK
;
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
>
template
<
typename
DstBlockTile
,
typename
SrcTileWindow
,
typename
DramTileWindowStep
>
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
CK_TILE_DEVICE
void
GlobalPrefetch
(
DstBlockTile
&
dst_block_tile
,
SrcTileWindow
&
dram_tile_window
)
const
SrcTileWindow
&
dram_tile_window
,
const
DramTileWindowStep
&
dram_tile_window_step
)
const
{
{
load_tile
(
dst_block_tile
,
dram_tile_window
);
load_tile
(
dst_block_tile
,
dram_tile_window
);
move_tile_window
(
dram_tile_window
,
{
0
,
KPerBlock
}
);
move_tile_window
(
dram_tile_window
,
dram_tile_window_step
);
}
}
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
template
<
typename
DstTileWindow
,
typename
SrcBlockTile
,
typename
ElementFunction
>
...
@@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase
...
@@ -60,19 +64,21 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE
auto
GetAWindows
(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
CK_TILE_DEVICE
auto
GetAWindows
(
const
ADramBlockWindowTmp
&
a_dram_block_window_tmp
,
const
ALdsTensorView
&
a_lds_block_view
)
const
const
ALdsTensorView
&
a_lds_block_view
)
const
{
{
constexpr
bool
is_col_major
=
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
;
using
YPerTile
=
std
::
conditional_t
<
is_col_major
,
number
<
KPerBlock
>
,
number
<
MPerBlock
>>
;
using
XPerTile
=
std
::
conditional_t
<
is_col_major
,
number
<
MPerBlock
>
,
number
<
KPerBlock
>>
;
// A DRAM tile window for load
// A DRAM tile window for load
auto
a_copy_dram_window
=
auto
a_copy_dram_window
=
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
a_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
make_tuple
(
YPerTile
{},
XPerTile
{}),
a_dram_block_window_tmp
.
get_window_origin
(),
a_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
// A LDS tile window for store
auto
a_copy_lds_window
=
auto
a_copy_lds_window
=
make_tile_window
(
make_tile_window
(
a_lds_block_view
,
a_lds_block_view
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
auto
a_lds_gemm_window
=
make_tile_window
(
auto
a_lds_gemm_window
=
make_tile_window
(
a_lds_block_view
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
a_lds_block_view
,
make_tuple
(
number
<
MPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
...
@@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase
...
@@ -86,18 +92,22 @@ struct GemmPipelineAgBgCrImplBase
CK_TILE_DEVICE
auto
GetBWindows
(
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
CK_TILE_DEVICE
auto
GetBWindows
(
const
BDramBlockWindowTmp
&
b_dram_block_window_tmp
,
const
BLdsTensorView
&
b_lds_block_view
)
const
const
BLdsTensorView
&
b_lds_block_view
)
const
{
{
constexpr
bool
is_row_major
=
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
;
using
YPerTile
=
std
::
conditional_t
<
is_row_major
,
number
<
KPerBlock
>
,
number
<
NPerBlock
>>
;
using
XPerTile
=
std
::
conditional_t
<
is_row_major
,
number
<
NPerBlock
>
,
number
<
KPerBlock
>>
;
auto
b_copy_dram_window
=
auto
b_copy_dram_window
=
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tile_window
(
b_dram_block_window_tmp
.
get_bottom_tensor_view
(),
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
make_tuple
(
YPerTile
{},
XPerTile
{}),
b_dram_block_window_tmp
.
get_window_origin
(),
b_dram_block_window_tmp
.
get_window_origin
(),
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// TODO: Do we really need those two tile windows???
// They're exactly same...
// B LDS tile window for store
// B LDS tile window for store
auto
b_copy_lds_window
=
auto
b_copy_lds_window
=
make_tile_window
(
make_tile_window
(
b_lds_block_view
,
b_lds_block_view
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
auto
b_lds_gemm_window
=
make_tile_window
(
auto
b_lds_gemm_window
=
make_tile_window
(
b_lds_block_view
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
b_lds_block_view
,
make_tuple
(
number
<
NPerBlock
>
{},
number
<
KPerBlock
>
{}),
{
0
,
0
});
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment