"vscode:/vscode.git/clone" did not exist on "154e8c468ec9be542bc6a171c45ddb0be185184d"
Commit 28699402 authored by Bartlomiej Wroblewski's avatar Bartlomiej Wroblewski
Browse files

Review: Apply review suggestions

parent b6e14520
...@@ -56,13 +56,18 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8) ...@@ -56,13 +56,18 @@ add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8)
add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp) add_example_executable(example_gemm_xdl_fp8_bf8 gemm_xdl_fp8_bf8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp8_bf8)
if(GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942") list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp) add_example_executable(example_gemm_xdl_lds_direct_load_fp32 gemm_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp32)
add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp) add_example_executable(example_gemm_xdl_lds_direct_load_fp16 gemm_xdl_lds_direct_load_fp16.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16) add_example_dependencies(example_gemm_xdl example_gemm_xdl_lds_direct_load_fp16)
endif() set(target 1)
endif()
endforeach()
add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp) add_example_executable(example_gemm_xdl_fp16_fp8 gemm_xdl_fp16_fp8.cpp)
add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8) add_example_dependencies(example_gemm_xdl example_gemm_xdl_fp16_fp8)
...@@ -12,11 +12,6 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -12,11 +12,6 @@ foreach(gpu IN LISTS GPU_TARGETS)
add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_fp32 gemm_add_add_fastgelu_xdl_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_fp32)
if(GPU_TARGETS MATCHES "gfx90a" OR GPU_TARGETS MATCHES "gfx940" OR GPU_TARGETS MATCHES "gfx941" OR GPU_TARGETS MATCHES "gfx942")
add_example_executable(example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
endif()
if(USE_BITINT_EXTENSION_INT4) if(USE_BITINT_EXTENSION_INT4)
add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp) add_example_executable(example_gemm_add_add_fastgelu_xdl_int4 gemm_add_add_fastgelu_xdl_int4.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4) add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_int4)
...@@ -27,3 +22,15 @@ foreach(gpu IN LISTS GPU_TARGETS) ...@@ -27,3 +22,15 @@ foreach(gpu IN LISTS GPU_TARGETS)
set(target 1) set(target 1)
endif() endif()
endforeach() endforeach()
set(gpu_list "")
list(APPEND gpu_list gfx90a gfx940 gfx941 gfx942)
set(target 0)
foreach(gpu IN LISTS GPU_TARGETS)
if(gpu IN_LIST gpu_list AND target EQUAL 0)
add_example_executable(example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32 gemm_add_add_fastgelu_xdl_lds_direct_load_fp32.cpp)
add_example_dependencies(example_gemm_add_add_fastgelu_xdl example_gemm_add_add_fastgelu_xdl_lds_direct_load_fp32)
set(target 1)
endif()
endforeach()
\ No newline at end of file
// SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved.
#include "common.hpp" #include "common.hpp"
......
...@@ -58,8 +58,9 @@ inline bool is_xdl_supported() ...@@ -58,8 +58,9 @@ inline bool is_xdl_supported()
ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx942";
} }
inline bool is_direct_load_supported() inline bool is_lds_direct_load_supported()
{ {
// Check if direct loads from global memory to LDS are supported.
return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" || return ck::get_device_name() == "gfx90a" || ck::get_device_name() == "gfx940" ||
ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942"; ck::get_device_name() == "gfx941" || ck::get_device_name() == "gfx942";
} }
......
...@@ -38,8 +38,6 @@ namespace ck { ...@@ -38,8 +38,6 @@ namespace ck {
* - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64, * - threads in a wavefront must write contiguous data to LDS (when wavefront size is 64,
* they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way * they must write 64 contiguous DWORDs) - `ThreadClusterLengths` must be prepared in such a way
* to guarantee that. * to guarantee that.
*
* For now, only single LDS buffer is supported.
*/ */
template <typename ThreadGroup, template <typename ThreadGroup,
typename BlockSliceLengths, typename BlockSliceLengths,
...@@ -50,8 +48,7 @@ template <typename ThreadGroup, ...@@ -50,8 +48,7 @@ template <typename ThreadGroup,
typename DstDesc, typename DstDesc,
index_t SrcVectorDim, index_t SrcVectorDim,
index_t DstVectorDim, index_t DstVectorDim,
index_t ScalarPerVector, index_t ScalarPerVector>
index_t NumLdsBuffers = 1>
struct ThreadGroupTensorSliceTransfer_DirectLoad struct ThreadGroupTensorSliceTransfer_DirectLoad
{ {
static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension(); static constexpr index_t nDim = remove_reference_t<SrcDesc>::GetNumOfDimension();
...@@ -227,7 +224,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad ...@@ -227,7 +224,7 @@ struct ThreadGroupTensorSliceTransfer_DirectLoad
const bool is_src_valid = const bool is_src_valid =
coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_); coordinate_has_valid_offset_assuming_visible_index_is_valid(src_desc, src_coord_);
src_buf.template CopyTo<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>( src_buf.template DirectCopyToLds<remove_cvref_t<decltype(dst_buf)>, ScalarPerVector>(
dst_buf, src_offset, dst_offset, is_src_valid); dst_buf, src_offset, dst_offset, is_src_valid);
constexpr auto move_on_dim = [&]() constexpr constexpr auto move_on_dim = [&]() constexpr
......
...@@ -571,8 +571,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -571,8 +571,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype(a_block_desc_ak0_m_ak1), decltype(a_block_desc_ak0_m_ak1),
ABlockTransferSrcVectorDim, ABlockTransferSrcVectorDim,
2, 2,
ABlockTransferScalarPerVector, ABlockTransferScalarPerVector>(
NumGemmKPrefetchStage>(
a_grid_desc_ak0_m_ak1, a_grid_desc_ak0_m_ak1,
make_multi_index(0, m_block_data_idx_on_grid, 0), make_multi_index(0, m_block_data_idx_on_grid, 0),
a_block_desc_ak0_m_ak1, a_block_desc_ak0_m_ak1,
...@@ -588,8 +587,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad ...@@ -588,8 +587,7 @@ struct GridwiseGemmMultipleD_Xdl_CShuffle_LdsDirectLoad
decltype(b_block_desc_bk0_n_bk1), decltype(b_block_desc_bk0_n_bk1),
BBlockTransferSrcVectorDim, BBlockTransferSrcVectorDim,
2, 2,
BBlockTransferScalarPerVector, BBlockTransferScalarPerVector>(
NumGemmKPrefetchStage>(
b_grid_desc_bk0_n_bk1, b_grid_desc_bk0_n_bk1,
make_multi_index(0, n_block_data_idx_on_grid, 0), make_multi_index(0, n_block_data_idx_on_grid, 0),
b_block_desc_bk0_n_bk1, b_block_desc_bk0_n_bk1,
......
...@@ -15,6 +15,7 @@ enum struct PipelineVersion ...@@ -15,6 +15,7 @@ enum struct PipelineVersion
{ {
v1, v1,
v2, v2,
// v3 is only used in the Stream-K implementation.
v4, v4,
}; };
......
...@@ -174,8 +174,10 @@ struct DynamicBuffer ...@@ -174,8 +174,10 @@ struct DynamicBuffer
} }
template <typename DstBuffer, index_t NumElemsPerThread> template <typename DstBuffer, index_t NumElemsPerThread>
__host__ __device__ void __host__ __device__ void DirectCopyToLds(DstBuffer& dst_buf,
CopyTo(DstBuffer& dst_buf, index_t src_offset, index_t dst_offset, bool is_valid_element) const index_t src_offset,
index_t dst_offset,
bool is_valid_element) const
{ {
// Copy data from global to LDS memory using direct loads. // Copy data from global to LDS memory using direct loads.
static_assert(GetAddressSpace() == AddressSpaceEnum::Global, static_assert(GetAddressSpace() == AddressSpaceEnum::Global,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment