Commit 28699402 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Review: Apply review suggestions

parent b6e14520
......@@ -56,13 +56,18 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
if(GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
endif()
set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
......@@ -12,11 +12,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
if(GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
......@@ -27,3 +22,15 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1)
endif()
endforeach()
set(gpu_list "")
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
set(target 1)
endif()
endforeach()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp"
......
......@@ -58,8 +58,9 @@ inline bool is_xdl_supported()
ck::get_device_name() == "gfx942";
}
inline bool is_direct_load_supported()
inline bool is_lds_direct_load_supported()
{
// Check if direct loads from global memory to LDS are supported.
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
}
......
......@@ -38,8 +38,6 @@ namespace ck {
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* to guarantee that.
*
* For now, only single LDS buffer is supported.
*/
template <typename ThreadGroup,
typename BlockSliceLengths,
......@@ -50,8 +48,7 @@ template <typename ThreadGroup,
typename DstDesc,
index_t SrcVectorDim,
index_t DstVectorDim,
index_t ScalarPerVector,
index_t NumLdsBuffers = 1>
index_t ScalarPerVector>
struct ThreadGroupTensorSliceTransfer_DirectLoad
{
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
......@@ -227,7 +224,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_buf.template CopyTo<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, is_src_valid);
constexpr auto move_on_dim = [&]() constexpr
......
......@@ -571,8 +571,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcVectorDim,
2,
ABlockTransferScalarPerVector,
NumGemmKPrefetchStage>(
ABlockTransferScalarPerVector>(
a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1,
......@@ -588,8 +587,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcVectorDim,
2,
BBlockTransferScalarPerVector,
NumGemmKPrefetchStage>(
BBlockTransferScalarPerVector>(
b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1,
......
......@@ -15,6 +15,7 @@ enum struct PipelineVersion
{
v1,
v2,
// v3 is only used in the Stream-K implementation.
v4,
};
......
......@@ -174,8 +174,10 @@ struct DynamicBuffer
}
template <typename DstBuffer, index_t NumElemsPerThread>
__host__ __device__ void
CopyTo(DstBuffer& dst_buf, index_t src_offset, index_t dst_offset, bool is_valid_element) const
__host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
index_t src_offset,
index_t dst_offset,
bool is_valid_element) const
{
// Copy data from global to LDS memory using direct loads.
static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment