Commit d6262fde authored by Chao Liu's avatar Chao Liu
Browse files

extract channel bit

parent 9e0d6146
......@@ -408,6 +408,10 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
using ACoord = typename TensorCoordinate<AGlobalDesc>::type;
using BCoord = typename TensorCoordinate<BGlobalDesc>::type;
#define CHANNEL_BITS 5
#define NUM_CHANNELS (1 << CHANNEL_BITS)
#define CHANNEL_SHIFT 8
#define CHANNEL_MASK ((NUM_CHANNELS - 1) << CHANNEL_SHIFT)
for(index_t m_block_work_id = 0; m_block_work_id < MBlockWork; ++m_block_work_id)
{
......@@ -419,9 +423,13 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
afile.open("a_mblock_" + std::to_string(m_block_work_id) + "_nblock_" + std::to_string(n_block_work_id) + ".csv", std::fstream::out);
afile << "kblock, offset" << std::endl;
afile << "kblock, channel" << std::endl;
#if 0
for(index_t k_block_work_id = 0; k_block_work_id < KBlockWork; ++k_block_work_id)
#else
index_t k_block_work_id = 0;
#endif
{
for(index_t k = k_block_work_id * KPerBlock ; k < (k_block_work_id + 1) * KPerBlock; ++k)
{
......@@ -431,7 +439,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
if(a_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
afile << k_block_work_id * 100 << "," << a_coord.GetOffset() << std::endl;
uint32_t offset_u = static_cast<uint32_t>(a_coord.GetOffset()) * sizeof(Float);
uint32_t channel = (offset_u & CHANNEL_MASK) >> CHANNEL_SHIFT;
afile << k_block_work_id << "," << channel << std::endl;
}
}
}
......@@ -448,7 +460,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
bfile << "kblock, offset" << std::endl;
#if 0
for(index_t k_block_work_id = 0; k_block_work_id < KBlockWork; ++k_block_work_id)
#else
index_t k_block_work_id = 0;
#endif
{
for(index_t k = k_block_work_id * KPerBlock ; k < (k_block_work_id + 1) * KPerBlock; ++k)
{
......@@ -458,7 +474,11 @@ struct GridwiseGemmTransposedANormalBNormalC_v1
if(b_coord.IsOffsetValidAssumingUpperIndexIsValid())
{
bfile << k_block_work_id * 100<< "," << b_coord.GetOffset() << std::endl;
uint32_t offset_u = static_cast<uint32_t>(b_coord.GetOffset()) * sizeof(Float);
uint32_t channel = (offset_u & CHANNEL_MASK) >> CHANNEL_SHIFT;
bfile << k_block_work_id << "," << channel << std::endl;
}
}
}
......
......@@ -328,13 +328,13 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<0, 0>;
using RightPads = Sequence<0, 0>;
#elif 0
#elif 1
// 1x1, 14x14
constexpr index_t N = 128;
constexpr index_t C = 128;
constexpr index_t N = 64;
constexpr index_t C = 1024;
constexpr index_t HI = 14;
constexpr index_t WI = 14;
constexpr index_t K = 128;
constexpr index_t K = 2048;
constexpr index_t Y = 1;
constexpr index_t X = 1;
......@@ -373,7 +373,7 @@ int main(int argc, char* argv[])
using LeftPads = Sequence<1, 1>;
using RightPads = Sequence<1, 1>;
#elif 1
#elif 0
// 3x3, 14x14
constexpr index_t N = 128;
constexpr index_t C = 128;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment