Unverified Commit 261330bb authored by Zeyu WANG's avatar Zeyu WANG Committed by GitHub
Browse files

fix calc space bug (#91)

* fix calc space bug

* use python code to allocate the buffer for backward kernel
parent eb758335
......@@ -31,40 +31,3 @@ struct cutlass_dtype<__nv_fp8_e5m2> {
template <typename T>
using cutlass_dtype_t = typename cutlass_dtype<T>::type;
\ No newline at end of file
template<typename T>
struct DeviceAllocation {
T* ptr_ = nullptr;
size_t offset_ = 0;
size_t size_ = 0;
torch::Tensor tensor;
DeviceAllocation(DeviceAllocation const&) = delete;
DeviceAllocation& operator=(DeviceAllocation const&) = delete;
DeviceAllocation() = default;
DeviceAllocation(size_t size) { reset(size); }
~DeviceAllocation() {}
void reset(size_t size, size_t offset=0) {
size_t num_element = sizeof(T) * (size + offset);
auto options = torch::TensorOptions().dtype(torch::kByte).device(torch::kCUDA);
tensor = torch::empty(num_element, options);
ptr_ = tensor.data_ptr<T>();
size_ = size;
offset_ = offset;
}
T* get() {
return ptr_ + offset_;
}
const T* get() const {
return ptr_ + offset_;
}
size_t size() const { return size_; }
size_t get_storage_size() const { return (size_ + offset_) * sizeof(T); }
};
......@@ -225,11 +225,11 @@ public:
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
size_t workspace_bytes = 0;
// OdO vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;
// scaled LSE vector
workspace_bytes += B*H*Q * sizeof(ElementAccumulator);
workspace_bytes += sizeof(ElementAccumulator) * B*H*Q;
// FP32 versions of outputs that are churned (start off with Q only)
workspace_bytes += B*H*Q*D * sizeof(ElementAccumulator);
workspace_bytes += sizeof(ElementAccumulator) * B*H*Q*D;
return workspace_bytes;
}
......@@ -247,7 +247,7 @@ public:
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_scaled_lse);
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_dQ);
params_.dQ_acc = dQ_acc;
params_.dQ_acc_size = B*H*Q*D * sizeof(ElementAccumulator);
params_.dQ_acc_size = sizeof(ElementAccumulator) * B*H*Q*D;
auto args_sum_OdO = to_sum_OdO_arguments(args, sum_OdO, scaled_lse);
auto args_convert = to_convert_arguments(args, dQ_acc);
params_.op_sum_OdO.initialize(args_sum_OdO, nullptr, stream);
......@@ -274,9 +274,9 @@ public:
int Q = cutlass::round_up(static_cast<int>(Q_), 8); // Alignment
char* workspace_chr = reinterpret_cast<char*>(workspace);
ElementAccumulator* sum_OdO = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
workspace_chr += sizeof(ElementAccumulator) * B*H*Q;
ElementAccumulator* scaled_lse = reinterpret_cast<ElementAccumulator*>(workspace_chr);
workspace_chr += B*H*Q * sizeof(ElementAccumulator);
workspace_chr += sizeof(ElementAccumulator) * B*H*Q;
ElementAccumulator* dQ_acc = reinterpret_cast<ElementAccumulator*>(workspace_chr);
return initialize_split(args, dQ_acc, sum_OdO, scaled_lse, stream);
}
......
......@@ -174,13 +174,10 @@ struct BwdRunner {
Operation op;
size_t workspace_size = 0;
workspace_size = Operation::get_workspace_size(arguments);
DeviceAllocation<uint8_t> workspace(workspace_size);
uint8_t* workspace_ptr = workspace.get();
uint8_t* workspace_ptr = static_cast<uint8_t*>(workspace_buffer.data_ptr());
CUTLASS_CHECK(op.can_implement(arguments));
CUTLASS_CHECK(op.initialize(arguments, workspace.get()));
CUTLASS_CHECK(op.initialize(arguments, workspace_ptr));
CUTLASS_CHECK(op.run(at::cuda::getCurrentCUDAStream()));
}
......
......@@ -154,7 +154,7 @@ def _flash_attn_varlen_backward(
max_seqlen_qo_aligned = (max_seqlen_qo + 7) // 8 * 8
bs = cu_seqlens_qo.shape[0] - 1
workspace_bytes = 0
workspace_bytes += 4 * qo_total_len * num_qo_heads * head_dim_qk # dQ_acc
workspace_bytes += 4 * bs * max_seqlen_qo_aligned * num_qo_heads * head_dim_qk # dQ_acc
workspace_bytes += 4 * max_seqlen_qo_aligned * bs * num_qo_heads * 2 # sum_OdO and scaled_lse
if num_qo_heads != num_kv_heads:
workspace_bytes += 2 * kv_total_len * num_qo_heads * (head_dim_qk + head_dim_vo) # dKV_acc
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment