Commit e0c65b0d authored by chenpangpang's avatar chenpangpang
Browse files

Add new file

parent b2d7fb87
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import time
# Report which backend JAX is running on (GPU availability check).
_devices = jax.devices()
print("JAX版本:", jax.__version__)
print("可用设备:", _devices)
print("当前使用的设备:", _devices[0])
print("设备类型:", _devices[0].device_kind)
# 1. Basic GPU matrix-operation demo
def basic_gpu_operations(size=5000):
    """Time a large matrix multiplication on the default JAX device.

    Args:
        size: Side length of the square matrices (default 5000).

    Returns:
        The product matrix C = A @ B, shape (size, size).
    """
    print("\n=== 基础GPU矩阵运算 ===")
    key = random.PRNGKey(42)
    # Split the key so A and B are independent draws; the original reused
    # one key, which made A and B byte-identical matrices.
    key_a, key_b = random.split(key)
    A = random.normal(key_a, (size, size))
    B = random.normal(key_b, (size, size))
    # block_until_ready() is required for meaningful timing: JAX dispatches
    # asynchronously, so without it we would only time the dispatch call.
    start = time.time()
    C = jnp.dot(A, B).block_until_ready()
    compute_time = time.time() - start
    print(f"矩阵大小: {size}x{size}")
    print(f"GPU矩阵乘法耗时: {compute_time:.4f}秒")
    print(f"结果矩阵C的形状: {C.shape}")
    print(f"C的平均值: {jnp.mean(C):.6f}")
    print(f"C的标准差: {jnp.std(C):.6f}")
    return C
# 2. JIT compilation speedup demo
def jit_acceleration(shape=(10000, 1000)):
    """Compare eager vs JIT-compiled execution of an elementwise reduction.

    Args:
        shape: Shape of the random input array (default (10000, 1000)).

    Returns:
        The scalar result of the JIT-compiled function.
    """
    print("\n=== JIT编译加速示例 ===")

    def complex_function(x):
        # A few transcendental ops so there is something worth fusing.
        return jnp.sum(jnp.sin(x) * jnp.exp(-x**2) + jnp.log1p(jnp.abs(x)))

    key = random.PRNGKey(0)
    x = random.normal(key, shape)
    # Every timed call uses block_until_ready(): JAX dispatches
    # asynchronously, so without it the measured times are meaningless.
    start = time.time()
    result_no_jit = complex_function(x).block_until_ready()
    time_no_jit = time.time() - start
    jitted_function = jit(complex_function)
    # First call includes tracing + compilation overhead.
    start = time.time()
    result_jit = jitted_function(x).block_until_ready()
    time_first_call = time.time() - start
    # Second call hits the compilation cache.
    start = time.time()
    result_jit = jitted_function(x).block_until_ready()
    time_second_call = time.time() - start
    print(f"不使用JIT: {time_no_jit:.4f}秒")
    print(f"JIT首次调用(含编译): {time_first_call:.4f}秒")
    print(f"JIT后续调用: {time_second_call:.4f}秒")
    print(f"加速比: {time_no_jit/time_second_call:.2f}倍")
    return result_jit
# 3. Autodiff on the GPU
def autodiff_gpu(n_samples=10000, n_features=100):
    """Compute gradients of a linear-regression MSE loss with jax.grad.

    Args:
        n_samples: Number of synthetic data rows (default 10000).
        n_features: Number of input features (default 100).

    Returns:
        Tuple (dL/dw, dL/db) with shapes (n_features, 1) and (1,).
    """
    print("\n=== GPU上的自动微分 ===")

    def loss(params, x, y):
        # Mean squared error of the linear model x @ w + b.
        w, b = params
        predictions = jnp.dot(x, w) + b
        return jnp.mean((predictions - y)**2)

    # One distinct subkey per random draw; the original reused key1 for X,
    # the noise and the initial weights, making the draws statistically
    # dependent instead of independent.
    key = random.PRNGKey(1)
    key_x, key_w, key_b, key_noise, key_init = random.split(key, 5)
    X = random.normal(key_x, (n_samples, n_features))
    true_w = random.normal(key_w, (n_features, 1))
    true_b = random.normal(key_b, (1,))
    Y = jnp.dot(X, true_w) + true_b + 0.1 * random.normal(key_noise, (n_samples, 1))
    params = (random.normal(key_init, (n_features, 1)), jnp.zeros((1,)))
    # JIT-compiled gradient of the loss w.r.t. the params tuple.
    grad_loss = jit(grad(loss, argnums=0))
    start = time.time()
    gradients = grad_loss(params, X, Y)
    # Wait for the async computation so the timing reflects real work.
    jax.block_until_ready(gradients)
    grad_time = time.time() - start
    print(f"数据形状: X={X.shape}, Y={Y.shape}")
    print(f"GPU梯度计算耗时: {grad_time:.4f}秒")
    print(f"权重梯度形状: {gradients[0].shape}")
    print(f"偏置梯度形状: {gradients[1].shape}")
    return gradients
# 4. Batch-processing demo
def batch_processing(batch_size=10000, sample_size=1000):
    """Contrast a Python loop with a vmap-vectorized per-sample reduction.

    Args:
        batch_size: Number of samples (default 10000).
        sample_size: Length of each sample (default 1000).

    Returns:
        jnp array of shape (batch_size,), one scalar per sample.
    """
    print("\n=== GPU批量处理 ===")

    def process_sample(x):
        # Arbitrary per-sample reduction used for the comparison.
        return jnp.sum(jnp.tanh(x) * jnp.cos(x))

    batched_process = vmap(process_sample)
    key = random.PRNGKey(3)
    batch_data = random.normal(key, (batch_size, sample_size))

    def sequential_process(data):
        # Intentionally slow baseline: one dispatch per sample.
        results = []
        for i in range(data.shape[0]):
            results.append(process_sample(data[i]))
        return jnp.array(results)

    # block_until_ready() so timings measure compute, not async dispatch.
    start = time.time()
    sequential_results = sequential_process(batch_data).block_until_ready()
    sequential_time = time.time() - start
    start = time.time()
    batched_results = batched_process(batch_data).block_until_ready()
    batched_time = time.time() - start
    print(f"批量大小: {batch_size}, 样本大小: {sample_size}")
    print(f"逐个处理耗时: {sequential_time:.4f}秒")
    print(f"批量处理耗时: {batched_time:.4f}秒")
    print(f"加速比: {sequential_time/batched_time:.2f}倍")
    # Sanity check: both paths must agree numerically.
    error = jnp.max(jnp.abs(sequential_results - batched_results))
    print(f"结果最大误差: {error:.10f}")
    return batched_results
# 5. Advanced demo: a simple neural-network forward pass
def neural_network_example(input_size=784, hidden_size=128, output_size=10,
                           batch_size=1024):
    """Run one forward pass of a JIT-compiled two-layer MLP.

    Args:
        input_size: Input feature count (default 784, MNIST-sized).
        hidden_size: Hidden-layer width (default 128).
        output_size: Number of output classes (default 10).
        batch_size: Rows in the input batch (default 1024).

    Returns:
        Network outputs of shape (batch_size, output_size).
    """
    print("\n=== GPU上的神经网络层 ===")

    def relu(x):
        return jnp.maximum(0, x)

    def two_layer_net(params, x):
        # Dense -> ReLU -> Dense. (The original also defined a separate
        # dense_layer helper that was never called; removed.)
        w1, b1, w2, b2 = params
        hidden = relu(jnp.dot(x, w1) + b1)
        return jnp.dot(hidden, w2) + b2

    key = random.PRNGKey(4)
    # Three draws are needed (w1, w2, input batch); the original split off
    # a fourth key that was never used.
    key1, key2, key3 = random.split(key, 3)
    # He initialization keeps activation variance stable through ReLU.
    w1 = random.normal(key1, (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(key2, (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size)
    b2 = jnp.zeros(output_size)
    params = (w1, b1, w2, b2)
    x_batch = random.normal(key3, (batch_size, input_size))
    jitted_net = jit(two_layer_net)
    # block_until_ready() forces completion; JAX dispatches asynchronously.
    start = time.time()
    outputs = jitted_net(params, x_batch).block_until_ready()
    inference_time = time.time() - start
    print(f"网络架构: {input_size} -> {hidden_size} -> {output_size}")
    print(f"批量大小: {batch_size}")
    print(f"GPU推理耗时: {inference_time:.6f}秒")
    print(f"输出形状: {outputs.shape}")
    print(f"每个样本的平均输出: {jnp.mean(outputs, axis=0)}")
    return outputs
def main():
    """Run all demos in sequence, then record a short profiler trace."""
    print("="*50)
    print("JAX GPU Demo - 开始执行")
    print("="*50)
    try:
        # Example 1: basic GPU matrix ops
        result1 = basic_gpu_operations()
        # Example 2: JIT speedup
        result2 = jit_acceleration()
        # Example 3: automatic differentiation
        result3 = autodiff_gpu()
        # Example 4: batched processing with vmap
        result4 = batch_processing()
        # Example 5: neural-network forward pass
        result5 = neural_network_example()
        print("\n" + "="*50)
        print("所有示例执行完成!")
        print("="*50)
        print("\n验证GPU使用情况:")
        # create_perfetto_link=True blocks the process until the generated
        # Perfetto link is opened in a browser, which hangs non-interactive
        # runs; the trace is still written to /tmp/jax-trace with False.
        with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=False):
            test_matrix = jnp.ones((1000, 1000))
            test_result = jnp.dot(test_matrix, test_matrix.T)
            print(f"跟踪测试完成,结果形状: {test_result.shape}")
    except Exception as e:
        # Best-effort demo: report the failure and point at the most likely
        # cause (missing GPU-enabled JAX install) instead of crashing.
        print(f"执行出错: {e}")
        print("\n注意:请确保已经正确安装JAX GPU版本")
        print("安装命令: pip install --upgrade 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")


if __name__ == "__main__":
    main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment