Commit fd206995 authored by Adam Osewski
Browse files

Fix compilation and address reviewers' comments.

parent a0773ad8
......@@ -36,7 +36,7 @@ enum struct tile_distribution_pattern
block_raked,
};
struct TileDistributionEcodingPattern
struct TileDistributionEncodingPattern
{
};
......@@ -57,7 +57,7 @@ template <index_t BlockSize,
index_t XPerTile,
index_t VecSize,
tile_distribution_pattern DistributionPattern>
struct TileDistributionEncodingPattern2D : public TileDistributionEcodingPattern
struct TileDistributionEncodingPattern2D : public TileDistributionEncodingPattern
{
};
......@@ -68,7 +68,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
XPerTile,
VecSize,
tile_distribution_pattern::thread_raked>
: public TileDistributionEcodingPattern
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
......@@ -120,7 +120,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
XPerTile,
VecSize,
tile_distribution_pattern::warp_raked>
: public TileDistributionEcodingPattern
: public TileDistributionEncodingPattern
{
static_assert(XPerTile % VecSize == 0, "XPerTile must be a multiple of VecSize!");
......@@ -168,7 +168,7 @@ struct TileDistributionEncodingPattern2D<BlockSize,
XPerTile,
VecSize,
tile_distribution_pattern::block_raked>
: public TileDistributionEcodingPattern
: public TileDistributionEncodingPattern
{
// TODO: make pattern where below condition does not need to hold - GGemmMultiDSplitk!
......
......@@ -26,6 +26,9 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
{
constexpr auto I0 = number<0>{};
static_assert(std::is_same_v<typename InTensor::DataType, typename OutTensor::DataType>,
"Data type for InTensor and OutTensor must be the same!");
using DataType = typename InTensor::DataType;
constexpr auto y_in_desc = InTensor::get_tile_distribution().get_ys_to_d_descriptor();
......@@ -65,9 +68,7 @@ CK_TILE_DEVICE void transpose_tile2d_impl_in_thread(OutTensor& out_tensor,
return y_dim_out_to_in_;
}();
//
constexpr index_t NDimY = InTensor::get_tile_distribution().get_num_of_dimension_y();
constexpr auto y_lengths = to_sequence(y_in_desc.get_lengths());
// input and output vector dim in the order of input Y dims
......
......@@ -4,6 +4,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment