Unverified commit 4ed69f5d, authored by PanZezhong1725 and committed by GitHub
Browse files

Merge pull request #9 from InfiniTensor/pool_alignment

Added memory alignment in the memory pool
parents 3ed7e78f 6fcbb9ec
...@@ -3,7 +3,6 @@ ...@@ -3,7 +3,6 @@
#include "infinicore_infer.h" #include "infinicore_infer.h"
#include <map> #include <map>
#include <memory>
#include <set> #include <set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
...@@ -17,28 +16,35 @@ public: ...@@ -17,28 +16,35 @@ public:
class MemoryPool : public AllocatorBase { class MemoryPool : public AllocatorBase {
public: public:
MemoryPool(size_t initialSize = 0); static constexpr size_t DEFAULT_ALIGNMENT = 256;
explicit MemoryPool(size_t initialSize = 0, size_t alignment = DEFAULT_ALIGNMENT);
~MemoryPool(); ~MemoryPool();
void *alloc(size_t size) override; void *alloc(size_t size) override;
void release(void *ptr) override; void release(void *ptr) override;
size_t getAlignment() const { return _alignment; }
private: private:
struct Block { struct Block {
void *base; void *base;
void *ptr; void *ptr;
size_t size; size_t size;
bool is_free; bool is_free;
Block(void *b, void *p, size_t s, bool f) Block(void *b, void *p, size_t s, bool f)
: base(b), ptr(p), size(s), is_free(f) {} : base(b), ptr(p), size(s), is_free(f) {}
bool operator<(const Block &other) const { bool operator<(const Block &other) const {
return ptr < other.ptr; return ptr < other.ptr;
} }
}; };
void *allocateNewRegion(size_t size); void *allocateNewRegion(size_t size);
void insertFreeBlock(Block &&block);
void tryCoalesce(const Block &block); void tryCoalesce(const Block &block);
size_t _alignment;
std::vector<void *> _base_regions; std::vector<void *> _base_regions;
std::set<Block> _all_blocks; std::set<Block> _all_blocks;
std::multimap<size_t, std::set<Block>::iterator> _free_blocks; std::multimap<size_t, std::set<Block>::iterator> _free_blocks;
......
#include "../allocator.hpp"
#include "../utils.hpp"
#include <algorithm>
#include <cstdint>
#include <iostream>
#include <stdexcept>
/**
 * Constructs a memory pool with the given alignment.
 *
 * @param initialSize Size in bytes of a region to reserve up front; when zero,
 *                    no region is allocated during construction.
 * @param alignment   Alignment stored in `_alignment`; must be a non-zero
 *                    power of two.
 * @throws std::invalid_argument if `alignment` is zero or not a power of two.
 */
MemoryPool::MemoryPool(size_t initialSize, size_t alignment)
    : _alignment(alignment) {
    // A power of two has exactly one bit set, so clearing its lowest set bit
    // (x & (x - 1)) must leave zero. Zero itself is rejected separately.
    const bool has_extra_bits = (alignment & (alignment - 1)) != 0;
    const bool is_zero = (alignment == 0);
    if (has_extra_bits || is_zero) {
        throw std::invalid_argument("Alignment must be a power of two");
    }

    // Nothing to pre-allocate for an empty initial size.
    if (initialSize == 0) {
        return;
    }
    allocateNewRegion(initialSize);
}
MemoryPool::~MemoryPool() { MemoryPool::~MemoryPool() {
...@@ -19,10 +24,14 @@ void *MemoryPool::alloc(size_t size) { ...@@ -19,10 +24,14 @@ void *MemoryPool::alloc(size_t size) {
return nullptr; return nullptr;
} }
auto it = _free_blocks.lower_bound(size); // Calculate aligned size
const size_t aligned_size = (size + _alignment - 1) & ~(_alignment - 1);
// Find the first block with enough space (after alignment)
auto it = _free_blocks.lower_bound(aligned_size);
if (it == _free_blocks.end()) { if (it == _free_blocks.end()) {
allocateNewRegion(size); allocateNewRegion(aligned_size);
it = _free_blocks.lower_bound(size); it = _free_blocks.lower_bound(aligned_size);
if (it == _free_blocks.end()) { if (it == _free_blocks.end()) {
throw std::bad_alloc(); throw std::bad_alloc();
} }
...@@ -33,25 +42,26 @@ void *MemoryPool::alloc(size_t size) { ...@@ -33,25 +42,26 @@ void *MemoryPool::alloc(size_t size) {
_free_blocks.erase(it); _free_blocks.erase(it);
_all_blocks.erase(block_it); _all_blocks.erase(block_it);
if (block.size > size + 256) { // Align the pointer within the block
// Split size_t alignment_padding = reinterpret_cast<char *>(block.ptr) - reinterpret_cast<char *>(block.ptr);
void *alloc_ptr = block.ptr;
void *rem_ptr = static_cast<char *>(block.ptr) + size; // Calculate remaining space after allocation
size_t rem_size = block.size - size; const size_t remaining = block.size - aligned_size - alignment_padding;
Block alloc_block(block.base, alloc_ptr, size, false);
Block rem_block(block.base, rem_ptr, rem_size, true); // Create allocated block
Block alloc_block(block.base, block.ptr, aligned_size, false);
auto alloc_it = _all_blocks.insert(alloc_block).first; auto alloc_it = _all_blocks.insert(alloc_block).first;
auto rem_it = _all_blocks.insert(rem_block).first;
_free_blocks.emplace(rem_size, rem_it);
_ptr_to_block[alloc_ptr] = alloc_it;
return alloc_ptr;
} else {
// No split
block.is_free = false;
auto alloc_it = _all_blocks.insert(block).first;
_ptr_to_block[block.ptr] = alloc_it; _ptr_to_block[block.ptr] = alloc_it;
return block.ptr;
// Split remaining space if it's large enough
if (remaining >= _alignment) {
void *rem_ptr = static_cast<char *>(block.ptr) + aligned_size;
Block rem_block(block.base, rem_ptr, remaining, true);
auto rem_it = _all_blocks.insert(rem_block).first;
_free_blocks.emplace(remaining, rem_it);
} }
return block.ptr;
} }
void MemoryPool::release(void *ptr) { void MemoryPool::release(void *ptr) {
...@@ -74,12 +84,19 @@ void MemoryPool::release(void *ptr) { ...@@ -74,12 +84,19 @@ void MemoryPool::release(void *ptr) {
} }
void *MemoryPool::allocateNewRegion(size_t size) { void *MemoryPool::allocateNewRegion(size_t size) {
// Allocate exactly the requested size
void *ptr = nullptr; void *ptr = nullptr;
RUN_INFINI(infinirtMalloc(&ptr, size)); RUN_INFINI(infinirtMalloc(&ptr, size));
_base_regions.push_back(ptr); _base_regions.push_back(ptr);
Block new_block(ptr, ptr, size, true);
// Align the pointer within the allocated region
size_t alignment_padding = reinterpret_cast<char *>(ptr) - reinterpret_cast<char *>(ptr);
size_t usable_size = size - alignment_padding;
Block new_block(ptr, ptr, usable_size, true);
auto it = _all_blocks.insert(new_block).first; auto it = _all_blocks.insert(new_block).first;
_free_blocks.emplace(size, it); _free_blocks.emplace(usable_size, it);
return ptr; return ptr;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment