Unverified Commit 261330bb authored by Zeyu WANG's avatar Zeyu WANG Committed by GitHub
Browse files

fix calc space bug (#91)

* fix calc space bug

* use python code to allocate the buffer for backward kernel
parent eb758335
...@@ -30,41 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> { ...@@ -30,41 +30,4 @@ struct cutlass_dtype<__nv_fp8_e5m2> {
}; };
template <typename T> template <typename T>
using cutlass_dtype_t = typename cutlass_dtype<T>::type; using cutlass_dtype_t = typename cutlass_dtype<T>::type;
\ No newline at end of file
template<typename T>
struct DeviceAllocation {
T* ptr_ = nullptr;
size_t offset_ = 0;
size_t size_ = 0;
torch::Tensor tensor;
DeviceAllocation(DeviceAllocation const&) = delete;
DeviceAllocation& operator=(DeviceAllocation const&) = delete;
DeviceAllocation() = default;
DeviceAllocation(size_t size) { reset(size); }
~DeviceAllocation() {}
void reset(size_t size, size_t offset=0) {
size_t num_element = sizeof(T) * (size + offset);
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
tensor = torch::empty(num_element, options);
ptr_ = tensor.data_ptr<T>();
size_ = size;
offset_ = offset;
}
T* get() {
return ptr_ + offset_;
}
const T* get() const {
return ptr_ + offset_;
}
size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
};
...@@ -225,11 +225,11 @@ public: ...@@ -225,11 +225,11 @@ public:
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0; size_t workspace_bytes = 0;
// OdO vector // OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator); workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;
// scaled LSE vector // scaled LSE vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator); workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;
// FP32 versions of outputs that are churned (start off with Q only) // FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator); workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D;
return workspace_bytes; return workspace_bytes;
} }
...@@ -247,7 +247,7 @@ public: ...@@ -247,7 +247,7 @@ public:
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse); ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ); ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc; params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator); params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D;
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse); auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
auto args_convert = to_convert_arguments(args, dQ_acc); auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream); params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
...@@ -274,9 +274,9 @@ public: ...@@ -274,9 +274,9 @@ public:
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace); char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr); ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator); workspace_chr += sizeof(ElementAccumulator) * B*H*Q;
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr); ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator); workspace_chr += sizeof(ElementAccumulator) * B*H*Q;
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr); ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream); return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
} }
......
...@@ -174,13 +174,10 @@ struct BwdRunner { ...@@ -174,13 +174,10 @@ struct BwdRunner {
Operation op; Operation op;
size_t workspace_size = 0; uint8_t* workspace_ptr = static_cast<uint8_t*>(workspace_buffer.data_ptr());
workspace_size = Operation::get_workspace_size(arguments);
DeviceAllocation<uint8_t> workspace(workspace_size);
uint8_t* workspace_ptr = workspace.get();
CUTLASS_CHECK(op.can_implement(arguments)); CUTLASS_CHECK(op.can_implement(arguments));
CUTLASS_CHECK(op.initialize(arguments, workspace.get())); CUTLASS_CHECK(op.initialize(arguments, workspace_ptr));
CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream())); CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));
} }
......
...@@ -154,7 +154,7 @@ def _flash_attn_varlen_backward( ...@@ -154,7 +154,7 @@ def _flash_attn_varlen_backward(
max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8 max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
bs = cu_seqlens_qo.shape[0] - 1 bs = cu_seqlens_qo.shape[0] - 1
workspace_bytes = 0 workspace_bytes = 0
workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc
workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse
if num_qo_heads != num_kv_heads: if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment