Commit 129e58ae authored by Adam Osewski

Merge remote-tracking branch 'origin/develop' into aosewski/ggemm_multi_d2

parents 9bebfd42 cb0645be
import functools
import os


@functools.lru_cache(None)
def library_path():
    return os.path.join(os.path.dirname(__file__), 'library')
@@ -40,6 +40,13 @@ function(add_test_executable TEST_NAME)
         endif()
     endforeach()
     endif()
+    if(INSTANCES_ONLY)
+        set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
+    else()
+        set(TEST_TARGETS ${GPU_TARGETS})
+    endif()
     foreach(source IN LISTS ARGN)
         if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
             message("removing dl test ${source} ")
@@ -47,20 +54,27 @@ function(add_test_executable TEST_NAME)
         endif()
     endforeach()
     foreach(source IN LISTS ARGN)
-        if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
+        if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
             message("removing xdl test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     foreach(source IN LISTS ARGN)
-        if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
+        if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
             message("removing wmma test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     #only continue if there are some source files left on the list
     if(ARGN)
+        if(ARGN MATCHES "_xdl")
+            list(REMOVE_ITEM TEST_TARGETS gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
+        elseif(ARGN MATCHES "_wmma")
+            list(REMOVE_ITEM TEST_TARGETS gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+        endif()
+        set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
         add_executable(${TEST_NAME} ${ARGN})
+        set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS})
         target_link_libraries(${TEST_NAME} PRIVATE getopt::getopt)
         add_test(NAME ${TEST_NAME} COMMAND $<TARGET_FILE:${TEST_NAME}>)
         add_dependencies(tests ${TEST_NAME})
@@ -105,6 +119,13 @@ function(add_gtest_executable TEST_NAME)
         endif()
     endforeach()
     endif()
+    if(INSTANCES_ONLY)
+        set(TEST_TARGETS ${DEFAULT_GPU_TARGETS})
+    else()
+        set(TEST_TARGETS ${GPU_TARGETS})
+    endif()
     foreach(source IN LISTS ARGN)
         if(NOT DEFINED DL_KERNELS AND source MATCHES "_dl")
             message("removing dl test ${source} ")
@@ -112,20 +133,27 @@ function(add_gtest_executable TEST_NAME)
         endif()
     endforeach()
     foreach(source IN LISTS ARGN)
-        if(NOT GPU_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
+        if(NOT TEST_TARGETS MATCHES "gfx9" AND source MATCHES "xdl")
             message("removing xdl test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     foreach(source IN LISTS ARGN)
-        if(NOT GPU_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
+        if(NOT TEST_TARGETS MATCHES "gfx11" AND source MATCHES "wmma")
             message("removing wmma test ${source} ")
             list(REMOVE_ITEM ARGN "${source}")
         endif()
     endforeach()
     #only continue if there are some source files left on the list
     if(ARGN)
+        if(ARGN MATCHES "_xdl")
+            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103)
+        elseif(ARGN MATCHES "_wmma")
+            list(REMOVE_ITEM TEST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030)
+        endif()
+        set_source_files_properties(${ARGN} PROPERTIES LANGUAGE HIP)
         add_executable(${TEST_NAME} ${ARGN})
+        set_property(TARGET ${TEST_NAME} PROPERTY HIP_ARCHITECTURES ${TEST_TARGETS})
         add_dependencies(tests ${TEST_NAME})
         add_dependencies(check ${TEST_NAME})
...
@@ -32,19 +32,8 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
     std::vector<ck::utils::conv::ConvParam> conv_params;
     std::vector<ck::index_t> split_ks{1, 2};
-    bool skip_case(const ck::utils::conv::ConvParam& params, const ck::index_t split_k)
+    bool skip_case(const ck::index_t split_k)
     {
-        // Odd K or C values are supported only by DL and WMMA
-        // kernels (only applies to fp16)
-        // DL and WMMA kernels currently support only `split_k=1`
-        if constexpr(std::is_same_v<InDataType, ck::half_t>)
-        {
-            if(split_k != 1 && (params.K_ % 2 != 0 || params.C_ % 2 != 0))
-            {
-                return true;
-            }
-        }
         // 1d NWGC is only supported by DL kernel
         // DL kernel is only supported for split_k=1
         if constexpr(std::is_same_v<InLayout, NWGC> && std::is_same_v<OutLayout, NWGK>)
@@ -100,7 +89,7 @@ class TestGroupedConvndBwdWeight : public ::testing::Test
         {
             for(auto& param : conv_params)
             {
-                if(!skip_case(param, split_k))
+                if(!skip_case(split_k))
                 {
                     pass = pass && ck::profiler::profile_grouped_conv_bwd_weight_impl<NDimSpatial{},
                                                                                       InLayout,
@@ -189,6 +178,8 @@ TYPED_TEST(TestGroupedConvndBwdWeight2d, Test2D)
     this->conv_params.push_back({2, 1, 1, 1, 32, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
     this->conv_params.push_back({2, 1, 1, 64, 3, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
     this->conv_params.push_back({2, 1, 1, 1, 1, {3, 3}, {32, 32}, {1, 1}, {1, 1}, {1, 1}, {1, 1}});
+    this->conv_params.push_back(
+        {2, 16, 16, 1, 1, {3, 3}, {28, 28}, {2, 2}, {1, 1}, {1, 1}, {1, 1}});
     this->Run();
 }
@@ -207,5 +198,7 @@ TYPED_TEST(TestGroupedConvndBwdWeight3d, Test3D)
         {3, 1, 1, 64, 3, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
     this->conv_params.push_back(
         {3, 1, 1, 1, 1, {3, 3, 3}, {32, 32, 32}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
+    this->conv_params.push_back(
+        {3, 16, 16, 1, 1, {3, 3, 3}, {28, 28, 28}, {2, 2, 2}, {1, 1, 1}, {1, 1, 1}, {1, 1, 1}});
     this->Run();
 }
@@ -6,6 +6,12 @@ if(result EQUAL 0)
     add_dependencies(test_grouped_gemm test_grouped_gemm_splitk)
 endif()
+add_gtest_executable(test_grouped_gemm_two_stage_splitk test_grouped_gemm_two_stage_multiple_d_splitk_xdl.cpp)
+if(result EQUAL 0)
+    target_link_libraries(test_grouped_gemm_two_stage_splitk PRIVATE utility device_grouped_gemm_instance)
+    add_dependencies(test_grouped_gemm test_grouped_gemm_two_stage_splitk)
+endif()
 add_gtest_executable(test_grouped_gemm_interface test_grouped_gemm_interface_xdl.cpp)
 if(result EQUAL 0)
     target_link_libraries(test_grouped_gemm_interface PRIVATE utility device_grouped_gemm_instance)
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include <tuple>
#include <vector>
#include "ck/tensor_operation/gpu/device/tensor_layout.hpp"
#include "ck/utility/data_type.hpp"
#include "gtest/gtest.h"
#include "test_grouped_gemm_util.hpp"
using F16 = ck::half_t;
using BF16 = ck::bhalf_t;
using I8 = int8_t;
using Row = ck::tensor_layout::gemm::RowMajor;
using Col = ck::tensor_layout::gemm::ColumnMajor;
using RRR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16 = ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
using RRR_F16_F16_F16_LargeK =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, F16, F16, F16>>;
using RCR_F16_F16_F16_LargeK =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, F16, F16, F16>>;
using RRR_BF16_BF16_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, BF16, BF16>>;
using RCR_BF16_BF16_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, BF16, BF16>>;
using RRR_BF16_I8_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Row, Row, BF16, I8, BF16>>;
using RCR_BF16_I8_BF16 =
ck::test::TestGroupedGemmTwoStage<std::tuple<Row, Col, Row, BF16, I8, BF16>>;
const std::vector<int> KBATCH{1, 2, 3, 5, 8};
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN,
RRR_F16_F16_F16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK,
RCR_F16_F16_F16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16,
RRR_BF16_BF16_BF16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16,
RCR_BF16_BF16_BF16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_KN_BF16_INT8,
RRR_BF16_I8_BF16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_MK_NK_BF16_INT8,
RCR_BF16_I8_BF16,
testing::ValuesIn(KBATCH));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_KN,
RRR_F16_F16_F16_LargeK,
testing::Values(32, 64));
INSTANTIATE_TEST_SUITE_P(TestGroupedGemmTwoStage_splitk_LargeK_MK_NK,
RCR_F16_F16_F16_LargeK,
testing::Values(32, 64));
#include "test_grouped_gemm_ut_cases.inc"
#include "test_grouped_gemm_two_stage_ut_cases.inc"
#pragma once
TEST_P(RRR_BF16_BF16_BF16, MNKPadded)
{
    const std::vector<int> Ms{127, 150, 188, 210};
    constexpr int N = 136;
    constexpr int K = 280;
    const std::vector<int> Ns(Ms.size(), N);
    const std::vector<int> Ks(Ms.size(), K);
    const std::vector<int> StrideAs(Ms.size(), K);
    const std::vector<int> StrideBs(Ms.size(), N);
    const std::vector<int> StrideCs(Ms.size(), N);
    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}

TEST_P(RCR_BF16_BF16_BF16, MNKPadded)
{
    const std::vector<int> Ms{127, 150, 188, 210};
    constexpr int N = 136;
    constexpr int K = 280;
    const std::vector<int> Ns(Ms.size(), N);
    const std::vector<int> Ks(Ms.size(), K);
    const std::vector<int> StrideAs(Ms.size(), K);
    const std::vector<int> StrideBs(Ms.size(), K);
    const std::vector<int> StrideCs(Ms.size(), N);
    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}

TEST_P(RRR_BF16_I8_BF16, MNKPadded)
{
    const std::vector<int> Ms{127, 150, 188, 210};
    constexpr int N = 136;
    constexpr int K = 280;
    const std::vector<int> Ns(Ms.size(), N);
    const std::vector<int> Ks(Ms.size(), K);
    const std::vector<int> StrideAs(Ms.size(), K);
    const std::vector<int> StrideBs(Ms.size(), N);
    const std::vector<int> StrideCs(Ms.size(), N);
    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}

TEST_P(RCR_BF16_I8_BF16, MNKPadded)
{
    const std::vector<int> Ms{127, 150, 188, 210};
    constexpr int N = 136;
    constexpr int K = 280;
    const std::vector<int> Ns(Ms.size(), N);
    const std::vector<int> Ks(Ms.size(), K);
    const std::vector<int> StrideAs(Ms.size(), K);
    const std::vector<int> StrideBs(Ms.size(), K);
    const std::vector<int> StrideCs(Ms.size(), N);
    this->Run(Ms, Ns, Ks, StrideAs, StrideBs, StrideCs, this->GetParam());
}
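
(Note: the stride choices in the MNKPadded cases above follow mechanically from the layouts: a row-major M x K matrix has leading stride K, a row-major K x N matrix has stride N, and a column-major K x N matrix has stride K. A minimal sketch of that rule, using a hypothetical helper that is not part of the CK test utilities:)

// Hypothetical helper, not from CK: leading stride of a dense matrix.
// Row-major: consecutive rows are `cols` elements apart.
// Column-major: consecutive columns are `rows` elements apart.
enum class Layout { RowMajor, ColumnMajor };

constexpr int leading_stride(Layout layout, int rows, int cols)
{
    return layout == Layout::RowMajor ? cols : rows;
}

// E.g. for the RCR cases above: A (M x K, row-major)    -> stride K,
//                               B (K x N, column-major) -> stride K,
//                               C (M x N, row-major)    -> stride N.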
 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
 #pragma once
@@ -22,6 +22,7 @@
 #include "ck/utility/tuple.hpp"
 #include "ck/utility/number.hpp"
 #include "profiler/profile_grouped_gemm_impl.hpp"
+#include "profiler/profile_grouped_gemm_two_stage_impl.hpp"
 namespace ck {
 namespace test {
@@ -90,6 +91,58 @@ class TestGroupedGemm : public testing::TestWithParam<int>
     }
 };
+template <typename Tuple>
+class TestGroupedGemmTwoStage : public testing::TestWithParam<int>
+{
+    protected:
+    using ALayout   = std::tuple_element_t<0, Tuple>;
+    using BLayout   = std::tuple_element_t<1, Tuple>;
+    using ELayout   = std::tuple_element_t<2, Tuple>;
+    using ADataType = std::tuple_element_t<3, Tuple>;
+    using BDataType = std::tuple_element_t<4, Tuple>;
+    using EDataType = std::tuple_element_t<5, Tuple>;
+
+    public:
+    static constexpr bool verify_     = true;
+    static constexpr int init_method_ = 1; // decimal value initialization
+    static constexpr bool log_        = false;
+    static constexpr bool bench_      = false; // measure kernel performance
+
+    void SetUp() override {}
+
+    void Run(const std::vector<int>& Ms,
+             const std::vector<int>& Ns,
+             const std::vector<int>& Ks,
+             const std::vector<int>& StrideAs,
+             const std::vector<int>& StrideBs,
+             const std::vector<int>& StrideCs,
+             int kbatch   = 1,
+             int n_warmup = 1,
+             int n_iter   = 10)
+    {
+        bool pass = ck::profiler::profile_grouped_gemm_two_stage_impl<ADataType,
+                                                                      BDataType,
+                                                                      EDataType,
+                                                                      float,
+                                                                      ALayout,
+                                                                      BLayout,
+                                                                      ELayout>(verify_,
+                                                                               init_method_,
+                                                                               log_,
+                                                                               bench_,
+                                                                               Ms,
+                                                                               Ns,
+                                                                               Ks,
+                                                                               StrideAs,
+                                                                               StrideBs,
+                                                                               StrideCs,
+                                                                               kbatch,
+                                                                               n_warmup,
+                                                                               n_iter);
+        EXPECT_TRUE(pass);
+    }
+};
 template <typename ALayout,
           typename BLayout,
           typename ELayout,
...
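
(As context for the split-k path exercised above: with kbatch > 1 the K dimension is split into kbatch slices, each producing a partial result, which a second stage then reduces into the output. A rough CPU model of that two-stage scheme — illustrative only, under the assumption that this is what "two stage" refers to here; none of these names come from the CK API:)

// Illustrative CPU model of two-stage split-K GEMM: stage 1 computes one
// partial product per K-slice, stage 2 reduces the partials into C.
// Not the CK implementation; for intuition only.
#include <vector>

void splitk_gemm(const std::vector<float>& A, // M x K, row-major
                 const std::vector<float>& B, // K x N, row-major
                 std::vector<float>& C,       // M x N, row-major
                 int M, int N, int K, int kbatch)
{
    // Stage 1: slice kb covers rows [k_begin, k_end) of B.
    std::vector<std::vector<float>> partials(kbatch, std::vector<float>(M * N, 0.f));
    for(int kb = 0; kb < kbatch; ++kb)
    {
        const int k_begin = kb * K / kbatch;
        const int k_end   = (kb + 1) * K / kbatch;
        for(int m = 0; m < M; ++m)
            for(int k = k_begin; k < k_end; ++k)
                for(int n = 0; n < N; ++n)
                    partials[kb][m * N + n] += A[m * K + k] * B[k * N + n];
    }
    // Stage 2: reduce the kbatch partial results into the output.
    C.assign(M * N, 0.f);
    for(int kb = 0; kb < kbatch; ++kb)
        for(int i = 0; i < M * N; ++i)
            C[i] += partials[kb][i];
}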
@@ -131,74 +131,74 @@ int main()
                                                         0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
-                                                        1, 0, 1, 2, 3, 4,
-                                                        2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
+                                                        -1,  0, -1, -2, -3, -4,
+                                                        -2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0,
-                                                        4, 3, 2, 1,
-                                                        5, 4, 3, 2});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0,
+                                                        -4, -3, -2, -1,
+                                                        -5, -4, -3, -2});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2,
-                                                        4, 3, 2, 1, 0, 1,
-                                                        5, 4, 3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2,
+                                                        -4, -3, -2, -1,  0, -1,
+                                                        -5, -4, -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
-                                                        1, 2, 3, 4,
-                                                        0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
+                                                        -1, -2, -3, -4,
+                                                         0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<true, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
     rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::VERTICAL, {0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5,
                                                         0, 1, 2, 3, 4, 5});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3, 4, 5,
-                                                        1, 0, 1, 2, 3, 4,
-                                                        2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3, -4, -5,
+                                                        -1,  0, -1, -2, -3, -4,
+                                                        -2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0,
-                                                        4, 3, 2, 1,
-                                                        5, 4, 3, 2});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0,
+                                                        -4, -3, -2, -1,
+                                                        -5, -4, -3, -2});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_TOP_LEFT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 1, 0, 1, 2, 3,
-                                                        3, 2, 1, 0, 1, 2,
-                                                        4, 3, 2, 1, 0, 1,
-                                                        5, 4, 3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(4, 6, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -1,  0, -1, -2, -3,
+                                                        -3, -2, -1,  0, -1, -2,
+                                                        -4, -3, -2, -1,  0, -1,
+                                                        -5, -4, -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {2, 3, 4, 5,
-                                                        1, 2, 3, 4,
-                                                        0, 1, 2, 3,
-                                                        1, 0, 1, 2,
-                                                        2, 1, 0, 1,
-                                                        3, 2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(6, 4, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {-2, -3, -4, -5,
+                                                        -1, -2, -3, -4,
+                                                         0, -1, -2, -3,
+                                                        -1,  0, -1, -2,
+                                                        -2, -1,  0, -1,
+                                                        -3, -2, -1,  0});
-    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, {0, 1, 2,
-                                                        1, 0, 1,
-                                                        2, 1, 0});
+    rtn &= test_alibi_traverse_with_slope<false, dtype>(3, 3, slope, ck_tile::AlibiMode::FROM_BOTTOM_RIGHT, { 0, -1, -2,
+                                                        -1,  0, -1,
+                                                        -2, -1,  0});
     rtn &= test_alibi_slope_generation<float>(8, {0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625});
     rtn &= test_alibi_slope_generation<float>(16, {0.7071067811865476, 0.5, 0.35355339059327384, 0.25000000000000006, 0.17677669529663692,
...
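
(The expected slopes in the two test_alibi_slope_generation cases match the standard ALiBi geometric sequence slope_i = 2^(-8(i+1)/n) for n heads: for n = 8 this gives 0.5, 0.25, ..., 0.00390625, and for n = 16 it starts at 2^(-0.5) = 0.7071... A quick sketch reproducing those vectors, assuming that formula; the ck_tile generator itself is not shown here:)

// Sketch of the standard ALiBi slope sequence: slope_i = 2^(-8 * (i + 1) / n)
// for n heads, n a power of two. Matches the expected values in the tests
// above; not taken from the ck_tile implementation.
#include <cmath>
#include <vector>

std::vector<float> alibi_slopes(int n_heads)
{
    std::vector<float> slopes(n_heads);
    for(int i = 0; i < n_heads; ++i)
        slopes[i] = std::pow(2.0f, -8.0f * (i + 1) / n_heads);
    return slopes;
}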