Unverified commit 1fcc5f7a, authored by Conglong Li, committed by GitHub
Browse files

Fix transformer kernel CUDA illegal memory access error (#765)

parent 68e138b6
...@@ -877,7 +877,11 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id, ...@@ -877,7 +877,11 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
seq_len = g_output.size(1); seq_len = g_output.size(1);
layer->SetSeqLength(seq_len); layer->SetSeqLength(seq_len);
} }
auto options = torch::TensorOptions()
.dtype(g_output.options().dtype())
.layout(torch::kStrided)
.device(torch::kCUDA)
.requires_grad(true);
auto workspace = torch::empty({get_workspace_size<T>(bsz, auto workspace = torch::empty({get_workspace_size<T>(bsz,
seq_len, seq_len,
layer->GetHiddenSize(), layer->GetHiddenSize(),
...@@ -885,7 +889,7 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id, ...@@ -885,7 +889,7 @@ std::vector<torch::Tensor> ds_transformer_backward(int layer_id,
layer->GetNumHeads(), layer->GetNumHeads(),
layer->IsTrainingMode(), layer->IsTrainingMode(),
layer->GeluCheckpoint())}, layer->GeluCheckpoint())},
grad_output.options()); options);
Context::Instance().SetWorkSpace((T*)workspace.data_ptr()); Context::Instance().SetWorkSpace((T*)workspace.data_ptr());
auto grad_input = torch::empty_like(input); auto grad_input = torch::empty_like(input);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment