Unverified Commit 6efeb743 authored by Jiaxing Ding's avatar Jiaxing Ding Committed by GitHub
Browse files

[AMD] fix bf16x2 dtype codegen (#847)

parent e7e38355
......@@ -480,7 +480,7 @@ void CodeGenTileLangHIP::PrintVecElemLoad(const std::string &vec, DataType t,
os << "((half2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.is_bfloat16()) {
os << "((nv_bfloat162*)(&(" << vec << "." << access[i / 2] << ")))->"
os << "((bfloat16x2*)(&(" << vec << "." << access[i / 2] << ")))->"
<< access[i % 2];
} else if (t.lanes() > 4 && t.lanes() <= 8) {
std::string type_name;
......
......@@ -67,7 +67,7 @@ using half_t = float16_t;
using bfloat16_t = hip_bfloat16;
struct bfloat16x2 {
bfloat16_t data[2];
bfloat16_t x, y;
};
struct bfloat16x4 {
......
......@@ -56,6 +56,7 @@ def tl_matmul(
A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
C_shared_shape = (
block_M // micro_size_x,
block_N // micro_size_y,
micro_size_x,
micro_size_y,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment