Commit 16a80bc6 authored by wooway777's avatar wooway777
Browse files

issue/730 - optimized debug

parent 11aa0c14
...@@ -114,32 +114,10 @@ void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, co ...@@ -114,32 +114,10 @@ void write_binary_data(std::ofstream &out, const T *data, const Shape &shape, co
void TensorImpl::debug(const std::string &filename) const { void TensorImpl::debug(const std::string &filename) const {
// Synchronize device if needed // Synchronize device if needed
context::syncDevice(); context::syncDevice();
std::cout << info() << std::endl; std::cout << info() << std::endl;
const std::byte *cpu_data = nullptr;
std::unique_ptr<std::byte[]> allocated_memory; // RAII: 自动管理内存 std::unique_ptr<std::byte[]> allocated_memory; // RAII: 自动管理内存
auto cpu_tensor = this->contiguous()->to(Device::cpu());
// Copy data to CPU if not already on CPU const std::byte *cpu_data = cpu_tensor->data();
if (this->device().getType() != Device::Type::CPU) {
size_t numel = this->numel();
size_t element_size = dsize(this->dtype());
// 检查乘法溢出
if (numel > 0 && element_size > std::numeric_limits<size_t>::max() / numel) {
std::cerr << "Error: Memory size calculation overflow for tensor with "
<< numel << " elements of size " << element_size << "\n";
return;
}
size_t mem_size = numel * element_size;
allocated_memory = std::make_unique<std::byte[]>(mem_size);
context::memcpyD2H(allocated_memory.get(), this->data(), mem_size);
cpu_data = allocated_memory.get();
} else {
cpu_data = this->data();
}
// If filename is provided, save to binary file // If filename is provided, save to binary file
if (!filename.empty()) { if (!filename.empty()) {
std::ofstream outFile(filename, std::ios::binary); std::ofstream outFile(filename, std::ios::binary);
...@@ -147,139 +125,72 @@ void TensorImpl::debug(const std::string &filename) const { ...@@ -147,139 +125,72 @@ void TensorImpl::debug(const std::string &filename) const {
std::cerr << "Error opening file for writing: " << filename << "\n"; std::cerr << "Error opening file for writing: " << filename << "\n";
return; // allocated_memory 会自动释放(RAII) return; // allocated_memory 会自动释放(RAII)
} }
// Fast path: contiguous tensor, write in one go
// Check if tensor is contiguous - for optimization size_t mem_size = cpu_tensor->numel() * dsize(cpu_tensor->dtype());
if (this->is_contiguous()) { outFile.write(reinterpret_cast<const char *>(cpu_data), mem_size);
// Fast path: contiguous tensor, write in one go
size_t mem_size = this->numel() * dsize(this->dtype());
outFile.write(reinterpret_cast<const char *>(cpu_data), mem_size);
} else {
// Slow path: non-contiguous tensor, write element by element using strides
switch (this->dtype()) {
case DataType::F16:
case DataType::BF16:
write_binary_data(outFile, reinterpret_cast<const uint16_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::F32:
write_binary_data(outFile, reinterpret_cast<const float *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::F64:
write_binary_data(outFile, reinterpret_cast<const double *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::U64:
write_binary_data(outFile, reinterpret_cast<const uint64_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::I64:
write_binary_data(outFile, reinterpret_cast<const int64_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::U32:
write_binary_data(outFile, reinterpret_cast<const uint32_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::I32:
write_binary_data(outFile, reinterpret_cast<const int32_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::U16:
write_binary_data(outFile, reinterpret_cast<const uint16_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::I16:
write_binary_data(outFile, reinterpret_cast<const int16_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::U8:
write_binary_data(outFile, reinterpret_cast<const uint8_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::I8:
write_binary_data(outFile, reinterpret_cast<const int8_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
case DataType::BOOL:
// 布尔类型特殊处理:转换为 uint8_t 以保证跨平台一致性
write_binary_data(outFile, reinterpret_cast<const uint8_t *>(cpu_data),
this->shape(), this->strides(), 0);
break;
default:
std::cerr << "Unsupported data type for binary output\n";
return;
}
}
// 显式关闭文件并检查是否成功 // 显式关闭文件并检查是否成功
outFile.close(); outFile.close();
if (!outFile) { if (!outFile) {
std::cerr << "Error: Failed to write data to file: " << filename << "\n"; std::cerr << "Error: Failed to write data to file: " << filename << "\n";
return; return;
} }
std::cout << "Data written to binary file: " << filename; std::cout << "Data written to binary file: " << filename;
if (!this->is_contiguous()) {
std::cout << " (non-contiguous tensor, wrote " << this->numel() << " elements)";
}
std::cout << "\n"; std::cout << "\n";
return; return;
} }
// Print data based on dtype // Print data based on dtype
switch (this->dtype()) { switch (cpu_tensor->dtype()) {
case DataType::F16: case DataType::F16:
print_data(reinterpret_cast<const uint16_t *>(cpu_data), print_data(reinterpret_cast<const uint16_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::F32: case DataType::F32:
print_data(reinterpret_cast<const float *>(cpu_data), print_data(reinterpret_cast<const float *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::F64: case DataType::F64:
print_data(reinterpret_cast<const double *>(cpu_data), print_data(reinterpret_cast<const double *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::U64: case DataType::U64:
print_data(reinterpret_cast<const uint64_t *>(cpu_data), print_data(reinterpret_cast<const uint64_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::I64: case DataType::I64:
print_data(reinterpret_cast<const int64_t *>(cpu_data), print_data(reinterpret_cast<const int64_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::U32: case DataType::U32:
print_data(reinterpret_cast<const uint32_t *>(cpu_data), print_data(reinterpret_cast<const uint32_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::I32: case DataType::I32:
print_data(reinterpret_cast<const int32_t *>(cpu_data), print_data(reinterpret_cast<const int32_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::U16: case DataType::U16:
print_data(reinterpret_cast<const uint16_t *>(cpu_data), print_data(reinterpret_cast<const uint16_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::I16: case DataType::I16:
print_data(reinterpret_cast<const int16_t *>(cpu_data), print_data(reinterpret_cast<const int16_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::U8: case DataType::U8:
print_data(reinterpret_cast<const uint8_t *>(cpu_data), print_data(reinterpret_cast<const uint8_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::I8: case DataType::I8:
print_data(reinterpret_cast<const int8_t *>(cpu_data), print_data(reinterpret_cast<const int8_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::BF16: case DataType::BF16:
print_data_bf16(reinterpret_cast<const uint16_t *>(cpu_data), print_data_bf16(reinterpret_cast<const uint16_t *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
case DataType::BOOL: case DataType::BOOL:
print_data(reinterpret_cast<const bool *>(cpu_data), print_data(reinterpret_cast<const bool *>(cpu_data),
this->shape(), this->strides(), 0); cpu_tensor->shape(), cpu_tensor->strides(), 0);
break; break;
default: default:
std::cout << "Unsupported data type for debug" << std::endl; std::cout << "Unsupported data type for debug" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment