Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
28699402
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "509e2dec8a1f475e9b54c1121835feb22c6fd005"
Commit
28699402
authored
Nov 20, 2023
by
Bartlomiej Wroblewski
Browse files
Review: Apply review suggestions
parent
b6e14520
Changes
8
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
36 additions
and
24 deletions
+36
-24
example/01_gemm/CMakeLists.txt
example/01_gemm/CMakeLists.txt
+12
-7
example/04_gemm_add_add_fastgelu/CMakeLists.txt
example/04_gemm_add_add_fastgelu/CMakeLists.txt
+12
-5
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp
...stgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp
+1
-0
include/ck/host_utility/device_prop.hpp
include/ck/host_utility/device_prop.hpp
+2
-1
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
.../block/thread_group_tensor_slice_transfer_direct_load.hpp
+2
-5
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
...gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
+2
-4
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
...or_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
+1
-0
include/ck/utility/dynamic_buffer.hpp
include/ck/utility/dynamic_buffer.hpp
+4
-2
No files found.
example/01_gemm/CMakeLists.txt
View file @
28699402
...
@@ -56,13 +56,18 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
...
@@ -56,13 +56,18 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_executable
(
example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp8_bf8
)
if
(
GPU_TARGETS MATCHES
"gfx90a"
OR GPU_TARGETS MATCHES
"gfx940"
OR GPU_TARGETS MATCHES
"gfx941"
OR GPU_TARGETS MATCHES
"gfx942"
)
list
(
APPEND gpu_list gfx90a gfx940 gfx941 gfx942
)
add_example_executable
(
example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp
)
set
(
target 0
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_example_executable
(
example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp
)
add_example_executable
(
example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32
)
endif
()
add_example_executable
(
example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16
)
set
(
target 1
)
endif
()
endforeach
()
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_executable
(
example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
add_example_dependencies
(
example_gemm_xdl example_gemm_xdl_fp16_fp8
)
example/04_gemm_add_add_fastgelu/CMakeLists.txt
View file @
28699402
...
@@ -12,11 +12,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -12,11 +12,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32
)
if
(
GPU_TARGETS MATCHES
"gfx90a"
OR GPU_TARGETS MATCHES
"gfx940"
OR GPU_TARGETS MATCHES
"gfx941"
OR GPU_TARGETS MATCHES
"gfx942"
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32
)
endif
()
if
(
USE_BITINT_EXTENSION_INT4
)
if
(
USE_BITINT_EXTENSION_INT4
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4
)
...
@@ -27,3 +22,15 @@ foreach(gpu IN LISTS GPU_TARGETS)
...
@@ -27,3 +22,15 @@ foreach(gpu IN LISTS GPU_TARGETS)
set
(
target 1
)
set
(
target 1
)
endif
()
endif
()
endforeach
()
endforeach
()
set
(
gpu_list
""
)
list
(
APPEND gpu_list gfx90a gfx940 gfx941 gfx942
)
set
(
target 0
)
foreach
(
gpu IN LISTS GPU_TARGETS
)
if
(
gpu IN_LIST gpu_list AND target EQUAL 0
)
add_example_executable
(
example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp
)
add_example_dependencies
(
example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32
)
set
(
target 1
)
endif
()
endforeach
()
\ No newline at end of file
example/04_gemm_add_add_fastgelu/gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp
View file @
28699402
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
#include "common.hpp"
...
...
include/ck/host_utility/device_prop.hpp
View file @
28699402
...
@@ -58,8 +58,9 @@ inline bool is_xdl_supported()
...
@@ -58,8 +58,9 @@ inline bool is_xdl_supported()
ck
::
get_device_name
()
==
"gfx942"
;
ck
::
get_device_name
()
==
"gfx942"
;
}
}
inline
bool
is_direct_load_supported
()
inline
bool
is_
lds_
direct_load_supported
()
{
{
// Check if direct loads from global memory to LDS are supported.
return
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
return
ck
::
get_device_name
()
==
"gfx90a"
||
ck
::
get_device_name
()
==
"gfx940"
||
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
ck
::
get_device_name
()
==
"gfx941"
||
ck
::
get_device_name
()
==
"gfx942"
;
}
}
...
...
include/ck/tensor_operation/gpu/block/thread_group_tensor_slice_transfer_direct_load.hpp
View file @
28699402
...
@@ -38,8 +38,6 @@ namespace ck {
...
@@ -38,8 +38,6 @@ namespace ck {
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* to guarantee that.
* to guarantee that.
*
* For now, only single LDS buffer is supported.
*/
*/
template
<
typename
ThreadGroup
,
template
<
typename
ThreadGroup
,
typename
BlockSliceLengths
,
typename
BlockSliceLengths
,
...
@@ -50,8 +48,7 @@ template <typename ThreadGroup,
...
@@ -50,8 +48,7 @@ template <typename ThreadGroup,
typename
DstDesc
,
typename
DstDesc
,
index_t
SrcVectorDim
,
index_t
SrcVectorDim
,
index_t
DstVectorDim
,
index_t
DstVectorDim
,
index_t
ScalarPerVector
,
index_t
ScalarPerVector
>
index_t
NumLdsBuffers
=
1
>
struct
ThreadGroupTensorSliceTransfer_DirectLoad
struct
ThreadGroupTensorSliceTransfer_DirectLoad
{
{
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
static
constexpr
index_t
nDim
=
remove_reference_t
<
SrcDesc
>::
GetNumOfDimension
();
...
@@ -227,7 +224,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
...
@@ -227,7 +224,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const
bool
is_src_valid
=
const
bool
is_src_valid
=
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_coord_
);
coordinate_has_valid_offset_assuming_visible_index_is_valid
(
src_desc
,
src_coord_
);
src_buf
.
template
CopyTo
<
remove_cvref_t
<
decltype
(
dst_buf
)>,
ScalarPerVector
>
(
src_buf
.
template
Direct
CopyTo
Lds
<
remove_cvref_t
<
decltype
(
dst_buf
)>,
ScalarPerVector
>
(
dst_buf
,
src_offset
,
dst_offset
,
is_src_valid
);
dst_buf
,
src_offset
,
dst_offset
,
is_src_valid
);
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
constexpr
auto
move_on_dim
=
[
&
]()
constexpr
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_xdl_cshuffle_lds_direct_load.hpp
View file @
28699402
...
@@ -571,8 +571,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
...
@@ -571,8 +571,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype
(
a_block_desc_ak0_m_ak1
),
decltype
(
a_block_desc_ak0_m_ak1
),
ABlockTransferSrcVectorDim
,
ABlockTransferSrcVectorDim
,
2
,
2
,
ABlockTransferScalarPerVector
,
ABlockTransferScalarPerVector
>
(
NumGemmKPrefetchStage
>
(
a_grid_desc_ak0_m_ak1
,
a_grid_desc_ak0_m_ak1
,
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
m_block_data_idx_on_grid
,
0
),
a_block_desc_ak0_m_ak1
,
a_block_desc_ak0_m_ak1
,
...
@@ -588,8 +587,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
...
@@ -588,8 +587,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype
(
b_block_desc_bk0_n_bk1
),
decltype
(
b_block_desc_bk0_n_bk1
),
BBlockTransferSrcVectorDim
,
BBlockTransferSrcVectorDim
,
2
,
2
,
BBlockTransferScalarPerVector
,
BBlockTransferScalarPerVector
>
(
NumGemmKPrefetchStage
>
(
b_grid_desc_bk0_n_bk1
,
b_grid_desc_bk0_n_bk1
,
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
make_multi_index
(
0
,
n_block_data_idx_on_grid
,
0
),
b_block_desc_bk0_n_bk1
,
b_block_desc_bk0_n_bk1
,
...
...
include/ck/tensor_operation/gpu/grid/gridwise_gemm_pipeline_selector.hpp
View file @
28699402
...
@@ -15,6 +15,7 @@ enum struct PipelineVersion
...
@@ -15,6 +15,7 @@ enum struct PipelineVersion
{
{
v1
,
v1
,
v2
,
v2
,
// v3 is only used in the Stream-K implementation.
v4
,
v4
,
};
};
...
...
include/ck/utility/dynamic_buffer.hpp
View file @
28699402
...
@@ -174,8 +174,10 @@ struct DynamicBuffer
...
@@ -174,8 +174,10 @@ struct DynamicBuffer
}
}
template
<
typename
DstBuffer
,
index_t
NumElemsPerThread
>
template
<
typename
DstBuffer
,
index_t
NumElemsPerThread
>
__host__
__device__
void
__host__
__device__
void
DirectCopyToLds
(
DstBuffer
&
dst_buf
,
CopyTo
(
DstBuffer
&
dst_buf
,
index_t
src_offset
,
index_t
dst_offset
,
bool
is_valid_element
)
const
index_t
src_offset
,
index_t
dst_offset
,
bool
is_valid_element
)
const
{
{
// Copy data from global to LDS memory using direct loads.
// Copy data from global to LDS memory using direct loads.
static_assert
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
,
static_assert
(
GetAddressSpace
()
==
AddressSpaceEnum
::
Global
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment