Commit bdcae547 authored by Tri Dao's avatar Tri Dao
Browse files

[LayerNorm] Don't exit early in the backward pass (fix #781)

parent 36bc29ed
...@@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel( ...@@ -452,8 +452,7 @@ def _layer_norm_bwd_kernel(
# Map the program id to the elements of X, DX, and DY it should compute. # Map the program id to the elements of X, DX, and DY it should compute.
row_block_id = tl.program_id(0) row_block_id = tl.program_id(0)
row_start = row_block_id * rows_per_program row_start = row_block_id * rows_per_program
if row_start >= M: # Do not early exit if row_start >= M, because we need to write DW and DB
return
cols = tl.arange(0, BLOCK_N) cols = tl.arange(0, BLOCK_N)
mask = cols < N mask = cols < N
X += row_start * stride_x_row X += row_start * stride_x_row
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment