Unverified Commit ebf30641 authored by zhang's avatar zhang Committed by GitHub
Browse files

Refine handling for q/v sequence length equals zero. (#92)

parent 261330bb
...@@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized { ...@@ -118,12 +118,15 @@ struct Sm100FmhaFwdEpilogueTmaWarpspecialized {
auto cumulative_length_q = get<0>(problem_shape).cumulative_length; auto cumulative_length_q = get<0>(problem_shape).cumulative_length;
if (cumulative_length_q != nullptr) { if (cumulative_length_q != nullptr) {
int max_length_q = get<0>(problem_shape).max_length; int max_length_q = get<0>(problem_shape).max_length;
get<0>(problem_shape_O).max_length = max(1, max_length_q);
// for variable sequence lenght, the batch is in units of row_stride // for variable sequence lenght, the batch is in units of row_stride
get<2,1>(dO) = get<0>(dO); get<2,1>(dO) = get<0>(dO);
get<2,1>(problem_shape_O) = max_length_q * (1 + get<2,1>(problem_shape_O)); get<2,1>(problem_shape_O) = max(1, max_length_q * (1 + get<2,1>(problem_shape_O)));
// offset ptr by the amount we add back in later // offset ptr by the amount we add back in later
ptr_O -= max_length_q * get<0>(dO); ptr_O -= max_length_q * get<0>(dO);
} }
} else {
get<0>(problem_shape_O) = max(1, get<0>(problem_shape_O));
} }
auto tma_store_o = make_tma_copy( auto tma_store_o = make_tma_copy(
......
...@@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized { ...@@ -1155,10 +1155,6 @@ struct Sm100FmhaFwdMainloopTmaWarpspecialized {
float lse = -INFINITY; float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1 #if 1
using ElementOut = typename CollectiveEpilogue::ElementOut; using ElementOut = typename CollectiveEpilogue::ElementOut;
......
...@@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized { ...@@ -112,6 +112,9 @@ struct Sm100FmhaLoadTmaWarpspecialized {
problem_shape_qk = problem_shape; problem_shape_qk = problem_shape;
} }
get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));
auto params_qk = CollectiveMmaQK::to_underlying_arguments( auto params_qk = CollectiveMmaQK::to_underlying_arguments(
problem_shape_qk, problem_shape_qk,
typename CollectiveMmaQK::Arguments { typename CollectiveMmaQK::Arguments {
......
...@@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized { ...@@ -1162,10 +1162,6 @@ struct Sm100MlaFwdMainloopTmaWarpspecialized {
float lse = -INFINITY; float lse = -INFINITY;
int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp); int thread_idx = threadIdx.x % (4 * NumThreadsPerWarp);
#define DSHOW(x) print(#x ": "); print(x); print("\n")
if (threadIdx.x % 128 == 0 && block0()) {
DSHOW(sO);
}
#if 1 #if 1
using ElementOut = typename CollectiveEpilogue::ElementOut; using ElementOut = typename CollectiveEpilogue::ElementOut;
......
...@@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized { ...@@ -119,6 +119,9 @@ struct Sm100MlaFwdLoadTmaWarpspecialized {
problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));; problem_shape_qk = replace<2>(problem_shape, get<2, 0>(problem_shape) + get<2, 1>(problem_shape));;
} }
get<0>(problem_shape_qk) = max(1, get<0>(problem_shape_qk));
get<1>(problem_shape_qk) = max(1, get<1>(problem_shape_qk));
auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape)); auto problem_shape_pv = replace<1>(select<0,2,1,3>(problem_shape_qk), get<2, 0>(problem_shape));
auto params_qk = CollectiveMmaQK::to_underlying_arguments( auto params_qk = CollectiveMmaQK::to_underlying_arguments(
......
...@@ -208,6 +208,11 @@ public: ...@@ -208,6 +208,11 @@ public:
dim3 const block = Kernel::get_block_shape(); dim3 const block = Kernel::get_block_shape();
dim3 const grid = get_grid_shape(params); dim3 const grid = get_grid_shape(params);
// No need to launch the kernel
if(grid.x == 0 || grid.y == 0 || grid.z == 0) {
return Status::kSuccess;
}
// configure smem size and carveout // configure smem size and carveout
int smem_size = Kernel::SharedStorageSize; int smem_size = Kernel::SharedStorageSize;
......
...@@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler { ...@@ -160,7 +160,7 @@ struct CausalPersistentTileScheduler {
return Params { return Params {
num_blocks, num_blocks,
{ size<3,0>(problem_size) }, { num_m_blocks}, { size<3,1>(problem_size) }, { size<3,0>(problem_size) }, { max(1, num_m_blocks) }, { size<3,1>(problem_size) },
hw_info hw_info
}; };
} }
......
...@@ -123,7 +123,7 @@ struct PersistentTileScheduler { ...@@ -123,7 +123,7 @@ struct PersistentTileScheduler {
return Params { return Params {
num_blocks, num_blocks,
{ num_m_blocks}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) }, { max(1, num_m_blocks)}, { size<3,0>(problem_size) }, { size<3,1>(problem_size) },
hw_info hw_info
}; };
} }
......
...@@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window): ...@@ -29,6 +29,9 @@ def get_attn_bias(s_q, s_k, causal, window):
def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None: def assert_close(x: torch.Tensor, y: torch.Tensor, name: str) -> None:
close_tensor = torch.isclose(x.to(torch.float32), y.to(torch.float32), rtol=1e-5, atol=1e-5)
if close_tensor.all():
return
x, y = x.double(), y.double() x, y = x.double(), y.double()
RMSE = ((x - y) * (x - y)).mean().sqrt().item() RMSE = ((x - y) * (x - y)).mean().sqrt().item()
cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12) cos_diff = 1 - 2 * (x * y).sum().item() / max((x * x + y * y).sum().item(), 1e-12)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment