Commit a4f148b6 authored by Tri Dao

Fix masking of bwd when seqlen is not divisible by 128

parent 184b992d
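Context for the change: when seqlen_k is not a multiple of kBlockN (128 in the default configs), the last key block of the backward pass contains columns past the end of the sequence. The old code left those columns of acc_s unmasked, reasoning that the matching rows of K are zero and cancel in the dQ matmul. In fp16 that reasoning can fail: a large enough garbage score overflows to Inf when dS is formed, and Inf times zero then produces NaN in dQ. The diff below masks those columns explicitly and threads an Is_even_MN template flag through the kernels so the fully even case keeps its unmasked fast path.

A minimal sketch of the failure mode (plain PyTorch, not from the kernel; the values are illustrative):

import torch

oob = torch.tensor(1e6)                 # garbage value in an out-of-bounds column of acc_s
ds = (oob * 2.0).to(torch.float16)      # forming dS overflows fp16 (max ~65504) -> inf
print(ds)                               # tensor(inf, dtype=torch.float16)
print(ds * 0.0)                         # inf * 0 -> nan, which then pollutes dQ
p = torch.tensor(float('-inf')).exp()   # with masking: exp(-inf) == 0
print(p * 2.0)                          # tensor(0.), finite, nothing propagates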
@@ -415,7 +415,7 @@ inline __device__ void convert_dKV(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, bool Is_first, bool Is_last, bool Seq_parallel=false, typename Params>
inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const int bidb, const int bidh, const int n_block) {
using Element = typename Kernel_traits::Element;
@@ -436,7 +436,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
constexpr int AtomLayoutMS = Kernel_traits::AtomLayoutMSdP;
constexpr bool Double_buffer = !Kernel_traits::No_double_buffer;
-const BlockInfo</*Varlen=*/!Is_even_M> binfo(params, bidb);
+const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
if (n_block * kBlockN >= binfo.actual_seqlen_k || binfo.actual_seqlen_q == 0) return;
int m_block_max = cute::ceil_div(binfo.actual_seqlen_q, kBlockM);
@@ -668,10 +668,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
return;
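For reference, a hypothetical Python mirror of the two flash::copy predicate modes used in this file (the real implementation is a CuTe tiled copy; this only sketches the row-predicate semantics): global-memory stores of dK/dV skip out-of-bounds rows outright, since Clear_OOB_MN=false must not write zeros past actual_seqlen_k, while shared-memory loads zero-fill those rows so later matmuls read zeros.

import torch

def predicated_copy(src, dst, max_mn, is_even_mn, clear_oob_mn):
    # Illustrative only: row-level version of the flash::copy predicates.
    for m in range(src.shape[0]):
        if is_even_mn or m < max_mn:
            dst[m] = src[m]        # in-bounds row: plain copy
        elif clear_oob_mn:
            dst[m] = 0             # smem load path: zero-fill the tail rows
        # gmem store path (clear_oob_mn=False): tail rows are skipped entirely

src, smem, gmem = torch.ones(4, 8), torch.empty(4, 8), torch.full((4, 8), 7.0)
predicated_copy(src, smem, max_mn=3, is_even_mn=False, clear_oob_mn=True)
predicated_copy(src, gmem, max_mn=3, is_even_mn=False, clear_oob_mn=False)
assert smem[3].eq(0).all() and gmem[3].eq(7).all()  # zero-filled vs. untouched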
@@ -687,7 +687,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
if (Kernel_traits::Is_V_in_regs) {
// Clear the smem tiles to account for predicated off loads
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
flash::cp_async_fence();
@@ -697,18 +697,18 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor tdOrO = make_fragment_like(tdOgO);
if (!Is_first) {
// Clear the smem tiles to account for predicated off loads
-flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_dO, tdOgdO, tdOsdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
} else {
-flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_dO, tdOgdO, tdOrdO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
-flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_dO, tdOgO, tdOrO, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
}
-flash::copy<Is_even_M, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tQgQ, tQsQ, tQcQ, tQpQ, binfo.actual_seqlen_q - m_block * kBlockM
);
@@ -722,7 +722,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
for (int mi = 0; mi < size(lse); ++mi) {
// Using uint32_t row makes it 10us slower on d=128, not sure why.
const int row = get<0>(taccScS_row(mi));
-lse(mi) = Is_even_M || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
+lse(mi) = Is_even_MN || row < binfo.actual_seqlen_q - m_block * kBlockM ? gLSE(row) : 0;
}
// Tensor tKrK = make_fragment_like(tKsK);
@@ -730,11 +730,11 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// copy(gmem_thr_copy_QKV, tKgK, tKrK);
// // if (cute::thread(1, 0)) { print(tKrK); }
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tKgK, tKsK, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
if (!Kernel_traits::Is_V_in_regs) {
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/true>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/true>(
gmem_thr_copy_QKV, tVgV, tVsV, tKVcKV, tKVpKV, binfo.actual_seqlen_k - n_block * kBlockN
);
}
@@ -783,15 +783,26 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
// Reshape acc_s from (MMA=4, MMA_N, MMA_N) to (col=(2, MMA_N), row=(2, MMA_N))
Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
// if (cute::thread(32, 0)) { print(scores); }
-// We don't need to mask out the elements beyond actual_seqlen_k, because acc_s would
-// be some finite value for those indices. In the end when we multiply with K to get dQ,
-// the corresponding values of K would be 0, so the result would still be correct.
-// Putting this causal masking right after acc_s is *much* slower for some reason.
-if (Is_causal && m_block * kBlockM < (n_block + 1) * kBlockN) {
-flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
-binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
-// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
-AtomLayoutMS * 16);
+// TD [2023-07-29]: I was thinking that we don't need to mask out the elements beyond
+// actual_seqlen_k, because acc_s would be some finite value for those indices.
+// In the end when we multiply with K to get dQ, the corresponding values of K would be 0,
+// so the result would still be correct.
+// However, it's possible that the values in acc_s are so large that they overflow
+// when we multiply with dP and convert to fp16, resulting in Inf in dS and NaNs in dQ.
+// So we need to mask out the elements beyond actual_seqlen_k.
+if (!Is_causal) {
+if (!Is_even_MN && (n_block + 1) * kBlockN >= binfo.actual_seqlen_k) {
+flash::apply_mask(scores, binfo.actual_seqlen_k,
+n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16);
+}
+} else {
+// Putting this causal masking right after acc_s is *much* slower for some reason.
+if (m_block * kBlockM < (n_block + 1) * kBlockN) {
+flash::apply_mask_causal(scores, n_block * kBlockN + (tidx / 32 / AtomLayoutMS) * MMA_N_SdP * 16,
+binfo.actual_seqlen_k, m_block * kBlockM + get<0>(taccScS_row(0)),
+// binfo.actual_seqlen_k, m_block * kBlockM + (tidx / 32) % AtomLayoutMS * 16 + (tidx % 32) / 4,
+AtomLayoutMS * 16);
+}
}
// if (cute::thread(32, 0)) { print(scores); }
// Compute the exponential value.
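The restructured branch above, rendered as a hypothetical whole-tile Python reference (the kernel applies the same conditions per thread on MMA fragments; this assumes the top-left-aligned causal convention used here):

import torch

def mask_scores(scores, m_block, n_block, kBlockM, kBlockN, seqlen_k,
                is_causal, is_even_mn):
    rows = m_block * kBlockM + torch.arange(scores.shape[0])[:, None]
    cols = n_block * kBlockN + torch.arange(scores.shape[1])[None, :]
    if not is_causal:
        # Only the last key block can hold columns >= seqlen_k, and only
        # when seqlen_k is not a multiple of kBlockN.
        if not is_even_mn and (n_block + 1) * kBlockN >= seqlen_k:
            scores.masked_fill_(cols >= seqlen_k, float('-inf'))
    else:
        # The causal mask subsumes the out-of-bounds mask: it removes both
        # col > row and col >= seqlen_k in one pass.
        if m_block * kBlockM < (n_block + 1) * kBlockN:
            scores.masked_fill_((cols > rows) | (cols >= seqlen_k), float('-inf'))
    return scores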
@@ -978,7 +989,7 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
Tensor tdQcdQ = gmem_thr_copy_dQ.partition_D(cdQ);
#pragma unroll
for (int m = 0; m < size<1>(tdQgdQ); ++m) {
-if (Is_even_M || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
+if (Is_even_MN || get<0>(tdQcdQ(0, m, 0)) < binfo.actual_seqlen_q - m_block * kBlockM) {
copy(gmem_thr_copy_dQ, tdQrdQ(_, m, _), tdQgdQ(_, m, _));
}
}
@@ -1044,10 +1055,10 @@ inline __device__ void compute_dq_dk_dv_1colblock(const Params &params, const in
#pragma unroll
for (int k = 0; k < size(tdKVpdKV); ++k) { tdKVpdKV(k) = get<1>(tdKVcdKV(0, 0, k)) < params.d; }
// Clear_OOB_K must be false since we don't want to write zeros to gmem
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdKrdK, tdKgdK, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
-flash::copy</*Is_even_MN=*/false, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
+flash::copy<Is_even_MN, Is_even_K, /*Clear_OOB_MN=*/false, /*Clear_OOB_K=*/false>(
gmem_thr_copy_dKV, tdVrdV, tdVgdV, tdKVcdKV, tdKVpdKV, binfo.actual_seqlen_k - n_block * kBlockN
);
@@ -1487,7 +1498,7 @@ inline __device__ void compute_dq_dk_dv(const Params &params) {
////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K, typename Params>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K, typename Params>
inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
const int n_block = blockIdx.x;
@@ -1496,7 +1507,7 @@ inline __device__ void compute_dq_dk_dv_seqk_parallel(const Params &params) {
// The block index for the head.
const int bidh = blockIdx.z;
-compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
+compute_dq_dk_dv_1colblock<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K, false, false, /*Seq_parallel=*/true>(params, bidb, bidh, n_block);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -23,9 +23,9 @@ __global__ void flash_bwd_dq_dk_dv_loop_kernel(Flash_bwd_params params) {
flash::compute_dq_dk_dv<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
}
-template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_M, bool Is_even_K>
+template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_MN, bool Is_even_K>
__global__ void flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel(Flash_bwd_params params) {
-flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_M, Is_even_K>(params);
+flash::compute_dq_dk_dv_seqk_parallel<Kernel_traits, Is_dropout, Is_causal, Is_even_MN, Is_even_K>(params);
}
template<typename Kernel_traits, bool Is_dropout, bool Is_causal, bool Is_even_N, bool Is_even_K>
@@ -53,17 +53,17 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params &params, cudaStream_t stream,
flash_bwd_dot_do_o_kernel<true, Kernel_traits><<<grid_m, Kernel_traits::kNThreads, 0, stream>>>(params);
C10_CUDA_KERNEL_LAUNCH_CHECK();
-// We also use is_even_M to set Unpadded in the BlockInfo constructor, so we need to check
-// for cu_seqlens_q as well.
-const bool is_even_M = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0;
+// We want to specialize to is_even_MN and not just is_even_M, since in the case where N is not
+// a multiple of kBlockN, we'll need to apply mask in the loop.
+const bool is_even_MN = params.cu_seqlens_q == nullptr && params.cu_seqlens_k == nullptr && params.seqlen_q % Kernel_traits::kBlockM == 0 && params.seqlen_k % Kernel_traits::kBlockN == 0;
const bool is_even_K = params.d == Kernel_traits::kHeadDim;
constexpr int smem_size_dq_dk_dv = Kernel_traits::kSmemSize1colblock;
// printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv);
BOOL_SWITCH(params.is_causal, IsCausalConst, [&] {
-BOOL_SWITCH(is_even_M, IsEvenMConst, [&] {
+BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
-auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, IsEvenKConst>;
-// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMConst, true>;
+auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, IsEvenKConst>;
+// auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenMNConst, true>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
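A hypothetical Python mirror of the dispatch predicate above (block sizes shown as 128 for illustration; the real values come from Kernel_traits): the fully even kernel is selected only for non-varlen batches whose sequence lengths divide both block sizes, otherwise the masking path is compiled in.

def is_even_mn(cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k,
               kBlockM=128, kBlockN=128):
    return (cu_seqlens_q is None and cu_seqlens_k is None
            and seqlen_q % kBlockM == 0 and seqlen_k % kBlockN == 0)

assert is_even_mn(None, None, 256, 256)              # even: unmasked fast path
assert not is_even_mn(None, None, 256, 200)          # seqlen_k % 128 != 0: mask
assert not is_even_mn([0, 200], [0, 200], 256, 256)  # varlen: mask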
@@ -102,7 +102,7 @@ void run_flash_bwd_seqq_parallel(Flash_bwd_params &params, cudaStream_t stream,
BOOL_SWITCH(is_even_N, IsEvenNConst, [&] {
BOOL_SWITCH(is_even_K, IsEvenKConst, [&] {
auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, IsEvenNConst, IsEvenKConst>;
-// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, Is_dropout, IsCausalConst, true, true>;
+// auto kernel = &flash_bwd_dq_dk_dv_loop_seqq_parallel_kernel<Kernel_traits, false, false, IsEvenNConst, IsEvenKConst>;
if (smem_size_dq_dk_dv >= 48 * 1024) {
C10_CUDA_CHECK(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv));
@@ -117,15 +117,18 @@ inline __device__ void max_scale_exp2_sum(Tensor<Engine0, Layout0> &tensor, Tens
}
template <typename Engine, typename Layout>
-inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k) {
+inline __device__ void apply_mask(Tensor<Engine, Layout> &tensor, const uint32_t max_seqlen_k,
+                                  const uint32_t col_idx_offset_ = 0) {
// tensor has shape (ncol=(2, MMA_M), nrow=(2, MMA_N))
static_assert(Layout::rank == 2, "Only support 2D Tensor");
const uint32_t lane_id = threadIdx.x % 32;
+const uint32_t col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2;
#pragma unroll
for (int nj = 0; nj < size<1, 1>(tensor); ++nj) {
+const uint32_t col_idx_base = col_idx_offset + nj * 8;
#pragma unroll
for (int j = 0; j < size<1, 0>(tensor); ++j) {
-const uint32_t col_idx = nj * 8 + j + (lane_id % 4) * 2;
+const uint32_t col_idx = col_idx_base + j;
if (col_idx >= max_seqlen_k) {
// Without the "make_coord" we get wrong results
#pragma unroll
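A sketch of the column indexing in apply_mask, assuming the m16n8 MMA layout the kernel uses: lanes 0-3 (lane_id % 4) each own two adjacent columns of every 8-column group, so hoisting the lane offset and the nj * 8 group offset into col_idx_base reproduces the old per-element formula nj * 8 + j + (lane_id % 4) * 2, while the new col_idx_offset_ argument lets the backward kernel pass in the block's global column offset.

def col_indices(lane_id, n_groups, col_idx_offset_=0):
    # Mirrors the loop structure above; n_groups plays the role of size<1, 1>.
    col_idx_offset = col_idx_offset_ + (lane_id % 4) * 2
    cols = []
    for nj in range(n_groups):
        col_idx_base = col_idx_offset + nj * 8
        for j in range(2):                   # size<1, 0> == 2 in this layout
            cols.append(col_idx_base + j)
    return cols

assert col_indices(lane_id=3, n_groups=2) == [6, 7, 14, 15]
assert col_indices(lane_id=3, n_groups=2, col_idx_offset_=128) == [134, 135, 142, 143]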
@@ -825,3 +825,97 @@ def test_flash_attn_race_condition(seqlen, d, dropout_p, causal, dtype):
assert torch.equal(dqkv[:, :, 0], dqkv0[:, :, 0])
assert torch.equal(dqkv[:, :, 1], dqkv0[:, :, 1])
assert torch.equal(dqkv[:, :, 2], dqkv0[:, :, 2])
+@pytest.mark.parametrize('dtype', [torch.float16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [16, 32, 64])
+# @pytest.mark.parametrize('d', [16])
+@pytest.mark.parametrize('seqlen', [1, 2, 5, 17, 128])
+# @pytest.mark.parametrize('seqlen', [2])
+def test_flash_attn_bwd_overflow(seqlen, d, causal, dtype):
+    """ We previously had a bug where not masking elements beyond seqlen_k caused NaN in dQ,
+    in the case where seqlen % 128 != 0.
+    """
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 2
+    nheads = 5
+    q = torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 5
+    k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda") * 3 for _ in range(2)]
+    q.requires_grad_(True)
+    k.requires_grad_(True)
+    v.requires_grad_(True)
+    out = flash_attn_func(q, k, v, causal=causal)
+    g = torch.randn_like(out)
+    out.backward(g)
+    q_pt = q.detach().clone().requires_grad_(True)
+    k_pt = k.detach().clone().requires_grad_(True)
+    v_pt = v.detach().clone().requires_grad_(True)
+    out_pt, _ = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
+    out_pt.backward(g)
+    q_ref = q.detach().clone().requires_grad_(True)
+    k_ref = k.detach().clone().requires_grad_(True)
+    v_ref = v.detach().clone().requires_grad_(True)
+    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
+    out_ref.backward(g)
+    print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    assert (q.grad - q_ref.grad).abs().max().item() <= 5 * (q_pt.grad - q_ref.grad).abs().max().item() + 1e-3
+    assert (k.grad - k_ref.grad).abs().max().item() <= 5 * (k_pt.grad - k_ref.grad).abs().max().item() + 1e-3
+    assert (v.grad - v_ref.grad).abs().max().item() <= 5 * (v_pt.grad - v_ref.grad).abs().max().item() + 1e-3
+@pytest.mark.parametrize('dtype', ([torch.float16] if is_sm75 else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize('dtype', [torch.bfloat16])
+@pytest.mark.parametrize('causal', [False, True])
+# @pytest.mark.parametrize('causal', [False])
+@pytest.mark.parametrize('d', [64, 128])
+# @pytest.mark.parametrize('d', [64])
+@pytest.mark.parametrize('seqlen', [97, 128, 200, 256])
+# @pytest.mark.parametrize('seqlen', [128])
+def test_flash_attn_bwd_transpose(seqlen, d, causal, dtype):
+    """ We previously had a bug where we were using the wrong strides of dout, which shows up
+    when dout is not contiguous.
+    """
+    device = 'cuda'
+    # set seed
+    torch.random.manual_seed(0)
+    batch_size = 5
+    nheads = 2
+    q, k, v = [torch.randn([batch_size, seqlen, nheads, d], dtype=dtype, device="cuda",
+                           requires_grad=True)
+               for _ in range(3)]
+    out = rearrange(flash_attn_func(q, k, v, causal=causal), "b s ... -> s b ...")
+    # So g is not contiguous
+    g = torch.randn(seqlen, 2 * batch_size, nheads, d, dtype=dtype, device="cuda")[:, ::2]
+    out.backward(g)
+    q_pt = q.detach().clone().requires_grad_(True)
+    k_pt = k.detach().clone().requires_grad_(True)
+    v_pt = v.detach().clone().requires_grad_(True)
+    out_pt, attn_pt = attention_ref(q_pt, k_pt, v_pt, causal=causal, upcast=False, reorder_ops=True)
+    out_pt = rearrange(out_pt, "b s ... -> s b ...")
+    out_pt.backward(g)
+    q_ref = q.detach().clone().requires_grad_(True)
+    k_ref = k.detach().clone().requires_grad_(True)
+    v_ref = v.detach().clone().requires_grad_(True)
+    out_ref, attn_ref = attention_ref(q_ref, k_ref, v_ref, causal=causal)
+    out_ref = rearrange(out_ref, "b s ... -> s b ...")
+    out_ref.backward(g)
+    print(f'dQ max diff: {(q.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK max diff: {(k.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV max diff: {(v.grad - v_ref.grad).abs().max().item()}')
+    print(f'dQ Pytorch max diff: {(q_pt.grad - q_ref.grad).abs().max().item()}')
+    print(f'dK Pytorch max diff: {(k_pt.grad - k_ref.grad).abs().max().item()}')
+    print(f'dV Pytorch max diff: {(v_pt.grad - v_ref.grad).abs().max().item()}')
+    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()
+    assert (q.grad - q_ref.grad).abs().max().item() <= 2 * (q_pt.grad - q_ref.grad).abs().max().item()
+    assert (k.grad - k_ref.grad).abs().max().item() <= 2 * (k_pt.grad - k_ref.grad).abs().max().item()
+    assert (v.grad - v_ref.grad).abs().max().item() <= 2 * (v_pt.grad - v_ref.grad).abs().max().item()
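Why the [:, ::2] slice in the test above exercises the dout-strides bug: a step-sliced view keeps the storage strides of its parent tensor, so g is not contiguous and the kernel must honor the real strides rather than assume a packed layout. A quick illustration:

import torch

g = torch.randn(4, 8)[:, ::2]
print(g.is_contiguous())   # False
print(g.stride())          # (8, 2): stride 2 along the sliced dimension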