Unverified Commit 4ed69f5d authored by PanZezhong1725's avatar PanZezhong1725 Committed by GitHub
Browse files

Merge pull request #9 from InfiniTensor/pool_alignment

Added memory alignment in the memory pool
parents 3ed7e78f 6fcbb9ec
......@@ -3,7 +3,6 @@
#include "infinicore_infer.h"
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <unordered_map>
#include <vector>
......@@ -17,28 +16,35 @@ public:
class MemoryPool : public AllocatorBase {
public:
MemoryPool(size_t initialSize = 0);
static constexpr size_t DEFAULT_ALIGNMENT = 256;
explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT);
~MemoryPool();
void *alloc(size_t size) override;
void release(void *ptr) override;
size_t getAlignment() const { return _alignment; }
private:
/// Bookkeeping record for one contiguous span of memory tracked by the pool.
struct Block {
    void *base;   ///< Start of the backing region this span was carved from.
    void *ptr;    ///< First byte of the span itself.
    size_t size;  ///< Length of the span in bytes.
    bool is_free; ///< True while the span is available for allocation.

    /// Builds a record from region base `b`, span start `p`, byte length `s`
    /// and availability flag `f`.
    Block(void *b, void *p, size_t s, bool f)
        : base(b), ptr(p), size(s), is_free(f) {}

    /// Orders blocks by the address of their span.
    ///
    /// Blocks from different regions end up in the same ordered container, and
    /// the built-in `<` yields an unspecified result for pointers that do not
    /// belong to the same array. `std::less` is guaranteed to impose a strict
    /// total order on pointers, so it is used instead.
    bool operator<(const Block &other) const {
        return std::less<const void *>{}(ptr, other.ptr);
    }
};
void *allocateNewRegion(size_t size);
void insertFreeBlock(Block &&block);
void tryCoalesce(const Block &block);
size_t _alignment;
std::vector<void *> _base_regions;
std::set<Block> _all_blocks;
std::multimap<size_t, std::set<Block>::iterator> _free_blocks;
......
#include "../allocator.hpp"
#include "../utils.hpp"
#include <algorithm>
#include <iostream>
#include <stdexcept>
MemoryPool::MemoryPool(size_t initialSize) {
/// Creates a pool that hands out memory using the given alignment.
///
/// @param initialSize Size in bytes of a region to allocate up front; when it
///                    is zero no region is requested during construction.
/// @param alignment   Alignment stored for later allocations; must be a
///                    non-zero power of two.
/// @throws std::invalid_argument if `alignment` is zero or not a power of two.
MemoryPool::MemoryPool(size_t initialSize, size_t alignment)
    : _alignment(alignment) {
    // A power of two has exactly one bit set, so clearing its lowest set bit
    // leaves zero. Zero itself passes that bit test and is excluded explicitly.
    const bool is_power_of_two = alignment != 0 && (alignment & (alignment - 1)) == 0;
    if (!is_power_of_two) {
        throw std::invalid_argument("Alignment must be a power of two");
    }

    if (initialSize == 0) {
        return;
    }
    allocateNewRegion(initialSize);
}
MemoryPool::~MemoryPool() {
......@@ -19,10 +24,14 @@ void *MemoryPool::alloc(size_t size) {
return nullptr;
}
auto it = _free_blocks.lower_bound(size);
// Calculate aligned size
const size_t aligned_size = (size + _alignment - 1) & ~(_alignment - 1);
// Find the first block with enough space (after alignment)
auto it = _free_blocks.lower_bound(aligned_size);
if (it == _free_blocks.end()) {
allocateNewRegion(size);
it = _free_blocks.lower_bound(size);
allocateNewRegion(aligned_size);
it = _free_blocks.lower_bound(aligned_size);
if (it == _free_blocks.end()) {
throw std::bad_alloc();
}
......@@ -33,25 +42,26 @@ void *MemoryPool::alloc(size_t size) {
_free_blocks.erase(it);
_all_blocks.erase(block_it);
if (block.size > size + 256) {
// Split
void *alloc_ptr = block.ptr;
void *rem_ptr = static_cast<char *>(block.ptr) + size;
size_t rem_size = block.size - size;
Block alloc_block(block.base, alloc_ptr, size, false);
Block rem_block(block.base, rem_ptr, rem_size, true);
// Align the pointer within the block
size_t alignment_padding = reinterpret_cast<char *>(block.ptr) - reinterpret_cast<char *>(block.ptr);
// Calculate remaining space after allocation
const size_t remaining = block.size - aligned_size - alignment_padding;
// Create allocated block
Block alloc_block(block.base, block.ptr, aligned_size, false);
auto alloc_it = _all_blocks.insert(alloc_block).first;
auto rem_it = _all_blocks.insert(rem_block).first;
_free_blocks.emplace(rem_size, rem_it);
_ptr_to_block[alloc_ptr] = alloc_it;
return alloc_ptr;
} else {
// No split
block.is_free = false;
auto alloc_it = _all_blocks.insert(block).first;
_ptr_to_block[block.ptr] = alloc_it;
return block.ptr;
// Split remaining space if it's large enough
if (remaining >= _alignment) {
void *rem_ptr = static_cast<char *>(block.ptr) + aligned_size;
Block rem_block(block.base, rem_ptr, remaining, true);
auto rem_it = _all_blocks.insert(rem_block).first;
_free_blocks.emplace(remaining, rem_it);
}
return block.ptr;
}
void MemoryPool::release(void *ptr) {
......@@ -74,12 +84,19 @@ void MemoryPool::release(void *ptr) {
}
void *MemoryPool::allocateNewRegion(size_t size) {
// Allocate exactly the requested size
void *ptr = nullptr;
RUN_INFINI(infinirtMalloc(&ptr, size));
_base_regions.push_back(ptr);
Block new_block(ptr, ptr, size, true);
// Align the pointer within the allocated region
size_t alignment_padding = reinterpret_cast<char *>(ptr) - reinterpret_cast<char *>(ptr);
size_t usable_size = size - alignment_padding;
Block new_block(ptr, ptr, usable_size, true);
auto it = _all_blocks.insert(new_block).first;
_free_blocks.emplace(size, it);
_free_blocks.emplace(usable_size, it);
return ptr;
}
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment