Unverified Commit b90a8d3a authored by cyanguwa's avatar cyanguwa Committed by GitHub
Browse files

Fix rng_state issue and minor compiler warning (#395)



fix rng_state issue and minor compiler warning
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>
parent 3b7b7c68
......@@ -181,9 +181,6 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
seqlens.fill_(config.seq_len)
cu_seqlens = torch.zeros(bs + 1, device = inp.device, dtype = torch.int32)
cu_seqlens[1:] = torch.cumsum(seqlens, dim = 0)
op_grad = torch.randn(
config.seq_len, bs, config.num_attention_heads * config.head_dim,
dtype = dtype).cuda()
sigma = 0.02
init_method = init_method_normal(sigma)
......@@ -241,7 +238,8 @@ def _run_transformer_layer(dtype, bs, config, backend, ckpt_attn, bias_type):
checkpoint_core_attention = ckpt_attn,
core_attention_bias_type = bias_type,
core_attention_bias = bias)
op.backward(op_grad)
loss = op.sum()
loss.backward()
return op, inp.grad
......
......@@ -293,8 +293,6 @@ transpose_dbias_kernel_notaligned(const Param param,
}
}
OVec out_trans[nvec_in]; // NOLINT(*)
const bool valid_store = my_place < tile_length &&
warp_id_in_tile * n_iterations + i < tile_height;
transpose_regs_partial_dbias(
in[current_in ^ 1],
out_trans,
......
......@@ -194,7 +194,13 @@ std::vector<at::Tensor> fused_attn_fwd_qkvpacked(
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
}
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
......@@ -497,7 +503,13 @@ std::vector<at::Tensor> fused_attn_fwd_kvpacked(
for (size_t i = 0; i < nvte_aux_tensor_pack.size; ++i) {
auto tensor = reinterpret_cast<transformer_engine::Tensor*>(nvte_aux_tensor_pack.tensors[i]);
// allocate memory for nvte_aux_tensor_pack.tensors
auto output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
at::Tensor output_tensor;
if (nvte_aux_tensor_pack.size >= 2) {
output_tensor = (i < nvte_aux_tensor_pack.size-1)
? allocateSpace(tensor->data.shape, tensor->data.dtype, false) : rng_state;
} else {
output_tensor = allocateSpace(tensor->data.shape, tensor->data.dtype, false);
}
output_tensors.push_back(output_tensor);
tensor->data.dptr = output_tensor.data_ptr();
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment